diff --git a/.github/workflows/docker-build-deploy.yml b/.github/workflows/docker-build-deploy.yml index b75b99c..13b3659 100644 --- a/.github/workflows/docker-build-deploy.yml +++ b/.github/workflows/docker-build-deploy.yml @@ -287,33 +287,6 @@ jobs: mkdir -p "$BACKUP_DIR" resolve_db_config - # region agent log - echo "=== DEBUG: migration database url selection ===" - CANDIDATE_DATABASE_URL="$(sed -n 's/^DATABASE_URL=//p' "$CANDIDATE_ENV" | head -n 1)" - CANDIDATE_DATABASE_URL="$CANDIDATE_DATABASE_URL" \ - EFFECTIVE_MIGRATION_DATABASE_URL="$EFFECTIVE_MIGRATION_DATABASE_URL" \ - python3 - <<'PY' - import os - from urllib.parse import urlsplit - - def host_of(value: str) -> str: - if not value: - return "" - value = value.strip().strip('"').strip("'") - return urlsplit(value).hostname or "" - - print(f"candidate_database_host={host_of(os.environ.get('CANDIDATE_DATABASE_URL', ''))}") - print(f"migration_database_host={host_of(os.environ.get('EFFECTIVE_MIGRATION_DATABASE_URL', ''))}") - PY - - docker run --rm \ - --network "$NETWORK_NAME" \ - --env-file "$CANDIDATE_ENV" \ - -e MIGRATION_DATABASE_URL="$EFFECTIVE_MIGRATION_DATABASE_URL" \ - --entrypoint python \ - "$IMAGE_TAG" -c "import os; from urllib.parse import urlsplit; mig=os.getenv('MIGRATION_DATABASE_URL','').strip(); db=os.getenv('DATABASE_URL','').strip(); print('container_migration_database_host=' + (urlsplit(mig).hostname if mig else '')); print('container_database_host=' + (urlsplit(db).hostname if db else ''))" - # endregion agent log - if docker ps --format '{{.Names}}' | grep -qx "$API_CONTAINER"; then CURRENT_API_RUNNING=1 fi diff --git a/api/agents/conversation_agent.py b/api/agents/conversation_agent.py index 4f6e98c..6a9f55f 100644 --- a/api/agents/conversation_agent.py +++ b/api/agents/conversation_agent.py @@ -5,6 +5,7 @@ """ import json import logging +from datetime import datetime from typing import List, Optional, Dict, Any from langchain_core.messages import HumanMessage, AIMessage @@ -50,10 +51,17 @@ class ConversationAgent: role: str, content: str, message_type: str = "text", + voice_session_id: str | None = None, + timestamp: datetime | str | int | None = None, ): """保存消息到 Redis""" await redis_service.add_message( - conversation_id, role, content, message_type=message_type + conversation_id, + role, + content, + message_type=message_type, + voice_session_id=voice_session_id, + timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp, ) def _format_history_string(self, messages: List[Any]) -> str: @@ -245,6 +253,8 @@ class ConversationAgent: filled_fields: Dict[str, str], nickname: str = "", is_from_voice: bool = False, + voice_session_id: str | None = None, + user_message_timestamp: datetime | None = None, ) -> List[str]: """在资料收集过程中生成跟进回复""" if not self.llm: @@ -260,7 +270,14 @@ class ConversationAgent: response_text = response.content if hasattr(response, 'content') else str(response) human_msg_type = "audio" if is_from_voice else "text" - await self._save_message(conversation_id, "human", user_message, message_type=human_msg_type) + await self._save_message( + conversation_id, + "human", + user_message, + message_type=human_msg_type, + voice_session_id=voice_session_id, + timestamp=user_message_timestamp, + ) await self._save_message(conversation_id, "ai", response_text) messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] @@ -296,6 +313,8 @@ class ConversationAgent: memoir_state: MemoirStateSchema, user_profile_context: str = "", is_from_voice: bool = False, + voice_session_id: str | None = None, + user_message_timestamp: datetime | None = None, ) -> List[str]: """ 基于共享状态异步生成引导式回复 @@ -347,7 +366,14 @@ class ConversationAgent: response_text = response.content if hasattr(response, 'content') else str(response) human_msg_type = "audio" if is_from_voice else "text" - await self._save_message(conversation_id, "human", user_message, message_type=human_msg_type) + await self._save_message( + conversation_id, + "human", + user_message, + message_type=human_msg_type, + voice_session_id=voice_session_id, + timestamp=user_message_timestamp, + ) await self._save_message(conversation_id, "ai", response_text) messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] diff --git a/api/database/models.py b/api/database/models.py index 436ca3a..145e009 100644 --- a/api/database/models.py +++ b/api/database/models.py @@ -51,6 +51,7 @@ class Conversation(Base): id = Column(String, primary_key=True) user_id = Column(String, ForeignKey("users.id"), nullable=False) started_at = Column(DateTime(timezone=True), default=utc_now) + last_message_at = Column(DateTime(timezone=True), nullable=True) ended_at = Column(DateTime(timezone=True), nullable=True) duration_seconds = Column(Integer, default=0) summary = Column(Text, nullable=True) diff --git a/api/migrations/sync_schema_to_models.sql b/api/migrations/sync_schema_to_models.sql index fae8478..9ff15de 100644 --- a/api/migrations/sync_schema_to_models.sql +++ b/api/migrations/sync_schema_to_models.sql @@ -40,8 +40,19 @@ BEGIN ALTER TABLE conversations ADD COLUMN conversation_stage VARCHAR; RAISE NOTICE '已添加 conversations.conversation_stage'; END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'conversations' AND column_name = 'last_message_at') THEN + ALTER TABLE conversations ADD COLUMN last_message_at TIMESTAMP WITH TIME ZONE; + RAISE NOTICE '已添加 conversations.last_message_at'; + END IF; END $$; +UPDATE conversations +SET last_message_at = started_at +WHERE last_message_at IS NULL + AND started_at IS NOT NULL; + +CREATE INDEX IF NOT EXISTS ix_conversations_last_message_at ON conversations(last_message_at); + -- ========== 4. chapters 表缺失列 ========== DO $$ BEGIN diff --git a/api/routers/conversations.py b/api/routers/conversations.py index 6bf0235..21a49d0 100644 --- a/api/routers/conversations.py +++ b/api/routers/conversations.py @@ -6,7 +6,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Body from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select +from sqlalchemy import func, select import uuid from database import get_async_db, Conversation, Segment, User @@ -17,6 +17,65 @@ from database.models import User as UserModel router = APIRouter(prefix="/api/conversations", tags=["conversations"]) +def _datetime_to_timestamp_ms(value: datetime | None) -> int: + if value is None: + return int(datetime.now(timezone.utc).timestamp() * 1000) + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return int(value.timestamp() * 1000) + + +def _message_timestamp_ms(msg: dict, fallback: datetime | None) -> int: + raw_timestamp = msg.get("timestamp") + if isinstance(raw_timestamp, (int, float)): + return int(raw_timestamp) + if isinstance(raw_timestamp, str): + try: + return int(datetime.fromisoformat(raw_timestamp.replace("Z", "+00:00")).timestamp() * 1000) + except ValueError: + pass + return _datetime_to_timestamp_ms(fallback) + + +def _latest_message_time_ms(conversation: ConversationModel, history: list[dict]) -> int: + if conversation.last_message_at: + return _datetime_to_timestamp_ms(conversation.last_message_at) + if history: + return _message_timestamp_ms(history[-1], conversation.started_at) + return _datetime_to_timestamp_ms(conversation.started_at) + + +def _build_messages_from_history( + conversation_id: str, + history: list[dict], + fallback_timestamp: datetime | None, +) -> list[dict]: + messages: list[dict] = [] + seen_audio_sessions: set[str] = set() + + for idx, msg in enumerate(history): + role = msg.get("role") + message_type = msg.get("messageType", "text") + voice_session_id = msg.get("voiceSessionId") + if role == "human" and message_type == "audio" and voice_session_id: + if voice_session_id in seen_audio_sessions: + continue + seen_audio_sessions.add(voice_session_id) + + messages.append( + { + "id": f"{conversation_id}_msg_{idx}", + "conversationId": conversation_id, + "content": msg.get("content", ""), + "senderType": "user" if role == "human" else "assistant", + "timestamp": _message_timestamp_ms(msg, fallback_timestamp), + "messageType": message_type, + } + ) + + return messages + + @router.get("") async def get_conversations( current_user: UserModel = Depends(get_current_user), @@ -25,7 +84,7 @@ async def get_conversations( """获取当前用户的所有对话列表(需要认证)""" stmt = select(ConversationModel).where( ConversationModel.user_id == current_user.id - ).order_by(ConversationModel.started_at.desc()) + ).order_by(func.coalesce(ConversationModel.last_message_at, ConversationModel.started_at).desc()) result = await db.execute(stmt) conversations = result.scalars().all() @@ -35,11 +94,12 @@ async def get_conversations( for conv in conversations: # 从Redis获取最新消息预览 latest_message = None + history: list[dict] = [] try: history = await redis_service.get_conversation_history(conv.id) if history: latest_message = history[-1].get("content", "")[:50] # 取前50个字符 - except: + except Exception: pass conversation_list.append({ @@ -47,7 +107,7 @@ async def get_conversations( "title": conv.summary[:30] if conv.summary else "岁月知己", # 使用summary作为标题,如果没有则使用默认标题 "avatarUrl": None, "latestMessagePreview": latest_message or conv.summary, - "latestMessageTime": int(conv.started_at.timestamp() * 1000) if conv.started_at else int(datetime.now(timezone.utc).timestamp() * 1000), + "latestMessageTime": _latest_message_time_ms(conv, history), "unreadCount": 0, "isDefaultAssistant": conv.summary is None # 如果没有summary,则认为是默认助手 }) @@ -187,18 +247,12 @@ async def get_messages( from services.redis_service import redis_service try: history = await redis_service.get_conversation_history(conversation_id) - messages = [] - for idx, msg in enumerate(history): - messages.append({ - "id": f"{conversation_id}_msg_{idx}", - "conversationId": conversation_id, - "content": msg.get("content", ""), - "senderType": "user" if msg.get("role") == "human" else "assistant", - "timestamp": int(datetime.now(timezone.utc).timestamp() * 1000), # Redis中没有时间戳,使用当前时间 - "messageType": msg.get("messageType", "text"), # 保留语音消息类型,使重新进入时仍显示为语音条 - }) - return messages - except Exception as e: + return _build_messages_from_history( + conversation_id=conversation_id, + history=history, + fallback_timestamp=conversation.started_at, + ) + except Exception: # 如果Redis中没有数据,返回空列表 return [] diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 376d78d..4d6e718 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -35,7 +35,7 @@ class MessageType(str, Enum): AUDIO_CHUNK = "audio_chunk" AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传) AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音) - TRANSCRIBE_ONLY = "transcribe_only" # 仅转写,不落库、不触发 Agent,用于「转文字」发送 + TRANSCRIBE_ONLY = "transcribe_only" # 仅转写,不落库、不触发 Agent,只返回转写结果 TEXT = "text" # 文本消息 TRANSCRIPT = "transcript" # 语音转文字结果 AGENT_RESPONSE = "agent_response" @@ -148,6 +148,16 @@ class SegmentStreamState: active_tasks: Set[asyncio.Task] = field(default_factory=set) +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime: + activity_time = at or _utc_now() + conversation.last_message_at = activity_time + return activity_time + + def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str: if voice_session_id: return str(voice_session_id) @@ -183,6 +193,13 @@ def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int] return None +def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]: + scope = _extract_segment_scope(audio_url) + if scope: + return scope[0] + return None + + def _is_transcribe_failure(transcript_text: Optional[str]) -> bool: if not transcript_text: return True @@ -357,6 +374,7 @@ async def _process_audio_segment_async( processed=False, ) db.add(segment) + user_message_timestamp = _mark_conversation_active(conversation) await db.commit() await db.refresh(segment) await manager.background_runner.queue_message(conversation.user_id, segment.id) @@ -383,6 +401,7 @@ async def _process_audio_segment_async( db=db, manager=manager, user=user, + user_message_timestamp=ordered_segment.created_at or user_message_timestamp, ) break @@ -564,6 +583,7 @@ async def websocket_endpoint( processed=False ) db.add(segment) + user_message_timestamp = _mark_conversation_active(conversation) await db.commit() await db.refresh(segment) await manager.background_runner.queue_message(conversation.user_id, segment.id) @@ -576,8 +596,9 @@ async def websocket_endpoint( segment=segment, db=db, manager=manager, - user=user, - ) + user=user, + user_message_timestamp=segment.created_at or user_message_timestamp, + ) elif msg_type == MessageType.AUDIO_SEGMENT: # 处理分段语音消息(长语音持续上传) @@ -726,6 +747,7 @@ async def websocket_endpoint( processed=False ) db.add(segment) + user_message_timestamp = _mark_conversation_active(conversation) await db.commit() await db.refresh(segment) await manager.background_runner.queue_message(conversation.user_id, segment.id) @@ -740,6 +762,7 @@ async def websocket_endpoint( db=db, manager=manager, user=user, + user_message_timestamp=segment.created_at or user_message_timestamp, ) else: # 转写失败,发送错误消息 @@ -758,7 +781,7 @@ async def websocket_endpoint( }) elif msg_type == MessageType.TRANSCRIBE_ONLY: - # 仅转写:不落库、不触发 Agent,用于客户端「转文字」后发文本 + # 仅转写:不落库、不触发 Agent,只把识别结果返回给客户端 data = message.get("data", {}) audio_base64 = data.get("audio_base64", "") if not audio_base64: @@ -906,6 +929,7 @@ async def process_user_message( db: AsyncSession, manager: ConnectionManager, user: UserModel = None, + user_message_timestamp: Optional[datetime] = None, ) -> None: """ 处理用户消息,生成Agent回应(异步版本) @@ -936,9 +960,12 @@ async def process_user_message( filled_fields=filled, nickname=user.nickname or "", is_from_voice=is_from_voice, + voice_session_id=_voice_session_id_from_audio_url(segment.audio_url), + user_message_timestamp=user_message_timestamp, ) segment.agent_response = "\n\n".join(responses) + _mark_conversation_active(conversation) await db.commit() for i, response_text in enumerate(responses): @@ -987,9 +1014,12 @@ async def process_user_message( memoir_state=state, user_profile_context=user_profile_context, is_from_voice=is_from_voice, + voice_session_id=_voice_session_id_from_audio_url(segment.audio_url), + user_message_timestamp=user_message_timestamp, ) segment.agent_response = "\n\n".join(responses) + _mark_conversation_active(conversation) await db.commit() for i, response_text in enumerate(responses): diff --git a/api/services/redis_service.py b/api/services/redis_service.py index ebb0000..0d0ae08 100644 --- a/api/services/redis_service.py +++ b/api/services/redis_service.py @@ -4,6 +4,7 @@ Redis 服务模块:用于会话状态存储和缓存 import os import json import logging +from datetime import datetime, timezone from typing import Optional, List, Dict, Any import redis.asyncio as aioredis @@ -77,6 +78,8 @@ class RedisService: role: str, content: str, message_type: str = "text", + voice_session_id: str | None = None, + timestamp: str | int | None = None, ) -> bool: """ 添加消息到对话历史 @@ -98,7 +101,15 @@ class RedisService: history = await self.get_conversation_history(conversation_id) # 添加新消息 - history.append({"role": role, "content": content, "messageType": message_type}) + item = { + "role": role, + "content": content, + "messageType": message_type, + "timestamp": timestamp or datetime.now(timezone.utc).isoformat(), + } + if voice_session_id: + item["voiceSessionId"] = voice_session_id + history.append(item) # 保存回 Redis(带过期时间) await client.setex(key, self.session_ttl, json.dumps(history, ensure_ascii=False)) diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py index 905defb..52eb46c 100644 --- a/api/tasks/memoir_tasks.py +++ b/api/tasks/memoir_tasks.py @@ -923,6 +923,9 @@ def generate_chapter_images(self, chapter_id: str): current_item["status"] = IMAGE_STATUS_COMPLETED current_item["error"] = None current_item["retryable"] = None + current_item["updated_at"] = datetime.now(timezone.utc).isoformat() + _apply_item_to_memoir_image(section.image_record, current_item) + db.commit() logger.info( "章节补图成功: chapter=%s, section_index=%s, url=%s", chapter_id, diff --git a/api/tests/test_conversation_messages_history.py b/api/tests/test_conversation_messages_history.py new file mode 100644 index 0000000..2ef540e --- /dev/null +++ b/api/tests/test_conversation_messages_history.py @@ -0,0 +1,117 @@ +import unittest +from datetime import datetime, timezone + +from database.models import Conversation + +from routers import conversations as conversations_router + + +class ConversationMessagesHistoryTest(unittest.TestCase): + def test_build_messages_collapses_audio_segments_from_same_voice_session(self): + history = [ + { + "role": "human", + "content": "第一段", + "messageType": "audio", + "voiceSessionId": "voice-1", + "timestamp": "2026-03-14T12:00:01+00:00", + }, + { + "role": "ai", + "content": "继续说", + "messageType": "text", + "timestamp": "2026-03-14T12:00:02+00:00", + }, + { + "role": "human", + "content": "第二段", + "messageType": "audio", + "voiceSessionId": "voice-1", + "timestamp": "2026-03-14T12:00:03+00:00", + }, + { + "role": "ai", + "content": "我记住了", + "messageType": "text", + "timestamp": "2026-03-14T12:00:04+00:00", + }, + ] + + messages = conversations_router._build_messages_from_history( + conversation_id="conv-1", + history=history, + fallback_timestamp=datetime(2026, 3, 14, 12, 0, 0, tzinfo=timezone.utc), + ) + + self.assertEqual( + [(msg["senderType"], msg["messageType"], msg["content"]) for msg in messages], + [ + ("user", "audio", "第一段"), + ("assistant", "text", "继续说"), + ("assistant", "text", "我记住了"), + ], + ) + self.assertEqual(messages[0]["timestamp"], 1773489601000) + self.assertEqual(messages[1]["timestamp"], 1773489602000) + self.assertEqual(messages[2]["timestamp"], 1773489604000) + + def test_build_messages_keeps_distinct_voice_sessions_separate(self): + history = [ + { + "role": "human", + "content": "第一次录音", + "messageType": "audio", + "voiceSessionId": "voice-1", + "timestamp": "2026-03-14T12:00:01+00:00", + }, + { + "role": "human", + "content": "第二次录音", + "messageType": "audio", + "voiceSessionId": "voice-2", + "timestamp": "2026-03-14T12:00:02+00:00", + }, + ] + + messages = conversations_router._build_messages_from_history( + conversation_id="conv-1", + history=history, + fallback_timestamp=datetime(2026, 3, 14, 12, 0, 0, tzinfo=timezone.utc), + ) + + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0]["messageType"], "audio") + self.assertEqual(messages[0]["content"], "第一次录音") + self.assertEqual(messages[1]["messageType"], "audio") + self.assertEqual(messages[1]["content"], "第二次录音") + + def test_latest_message_time_prefers_conversation_last_message_at(self): + conversation = Conversation( + id="conv-1", + user_id="user-1", + started_at=datetime(2026, 3, 9, 12, 0, 0, tzinfo=timezone.utc), + last_message_at=datetime(2026, 3, 14, 12, 0, 5, tzinfo=timezone.utc), + ) + history = [ + { + "role": "human", + "content": "旧消息", + "messageType": "text", + "timestamp": "2026-03-10T12:00:00+00:00", + } + ] + + latest_message_time = conversations_router._latest_message_time_ms(conversation, history) + + self.assertEqual(latest_message_time, 1773489605000) + + def test_message_timestamp_falls_back_to_started_at_for_legacy_history(self): + conversation = Conversation( + id="conv-1", + user_id="user-1", + started_at=datetime(2026, 3, 14, 12, 0, 0, tzinfo=timezone.utc), + ) + + timestamp = conversations_router._message_timestamp_ms({}, conversation.started_at) + + self.assertEqual(timestamp, 1773489600000) diff --git a/api/tests/test_generate_chapter_images_persistence.py b/api/tests/test_generate_chapter_images_persistence.py new file mode 100644 index 0000000..64f4663 --- /dev/null +++ b/api/tests/test_generate_chapter_images_persistence.py @@ -0,0 +1,102 @@ +import base64 +import unittest +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from tasks.memoir_tasks import generate_chapter_images + + +_ONE_BY_ONE_PNG = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+aF9sAAAAASUVORK5CYII=" +) + + +def _image_record(img_dict): + d = dict(img_dict or {}) + return SimpleNamespace( + order_index=d.get("index", 0), + placeholder=d.get("placeholder"), + description=d.get("description"), + status=d.get("status"), + prompt=d.get("prompt"), + url=d.get("url"), + storage_key=d.get("storage_key"), + provider=d.get("provider"), + style=d.get("style"), + size=d.get("size"), + error=d.get("error"), + retryable=d.get("retryable"), + created_at=d.get("created_at"), + updated_at=d.get("updated_at"), + ) + + +def _chapter_stub(): + rec = _image_record( + { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + } + ) + section = SimpleNamespace( + content="那条路我一直记得。", + image_id="image-1", + image_record=rec, + order_index=0, + ) + return SimpleNamespace( + id="chapter-1", + user_id="user-1", + title="童年的夏天", + category="childhood", + sections=[section], + images=[], + cover_image=None, + ) + + +class GenerateChapterImagesPersistenceTest(unittest.TestCase): + @patch("tasks.memoir_tasks.SessionLocal") + @patch("tasks.memoir_tasks.TencentCosStorageService") + @patch("tasks.memoir_tasks.LiblibImageProvider") + @patch("tasks.memoir_tasks.MemoirImagePromptService") + @patch("tasks.memoir_tasks._release_chapter_image_lock") + @patch("tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True) + def test_successful_generation_persists_completed_status( + self, + _acquire_lock_mock, + _release_lock_mock, + prompt_service_cls, + provider_cls, + storage_cls, + session_local_cls, + ): + chapter = _chapter_stub() + db = Mock() + db.execute.return_value.unique.return_value.scalar_one_or_none.return_value = chapter + session_local_cls.return_value = db + + prompt_service_cls.return_value.build_prompt.return_value = { + "prompt": "A serene southern China town", + "style": "watercolor", + "size": "1024x1024", + } + provider_inst = provider_cls.return_value + provider_inst.submit_generation.return_value = { + "status": "completed", + "image_url": "https://provider.example.com/1.png", + } + provider_inst.download_image.return_value = _ONE_BY_ONE_PNG + storage_cls.from_env.return_value.upload_bytes.return_value = ( + "https://cos.example.com/memoirs/user-1/chapter-1/0.png" + ) + + generate_chapter_images.run("chapter-1") + + record = chapter.sections[0].image_record + self.assertEqual(record.status, "completed") + self.assertEqual(record.url, "https://cos.example.com/memoirs/user-1/chapter-1/0.png") + self.assertEqual(record.prompt, "A serene southern China town") diff --git a/api/tests/test_generate_chapter_images_task.py b/api/tests/test_generate_chapter_images_task.py index 46f2ccf..5af4a34 100644 --- a/api/tests/test_generate_chapter_images_task.py +++ b/api/tests/test_generate_chapter_images_task.py @@ -13,14 +13,20 @@ def _section_image_record(img_dict): """把图片 dict 转成 image_record 用的 SimpleNamespace(可被任务更新属性)。""" d = dict(img_dict or {}) return SimpleNamespace( + order_index=d.get("index", 0), placeholder=d.get("placeholder"), description=d.get("description"), status=d.get("status"), prompt=d.get("prompt"), url=d.get("url"), storage_key=d.get("storage_key"), + provider=d.get("provider"), + style=d.get("style"), + size=d.get("size"), error=d.get("error"), retryable=d.get("retryable"), + created_at=d.get("created_at"), + updated_at=d.get("updated_at"), ) @@ -217,7 +223,7 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): prompt_service_cls.assert_not_called() provider_cls.assert_not_called() storage_cls.from_env.return_value.upload_bytes.assert_not_called() - db.commit.assert_called_once() + db.commit.assert_not_called() @patch("api.tasks.memoir_tasks.SessionLocal") @patch("api.tasks.memoir_tasks.TencentCosStorageService") diff --git a/api/tests/test_websocket_baseline.py b/api/tests/test_websocket_baseline.py index a995b70..011c217 100644 --- a/api/tests/test_websocket_baseline.py +++ b/api/tests/test_websocket_baseline.py @@ -210,6 +210,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(segments), 1) self.assertEqual(segments[0].transcript_text, "你好") self.assertIsNone(segments[0].audio_url) + self.assertIsNotNone(conversation.last_message_at) fake_manager.background_runner.queue_message.assert_awaited_once() process_user_message_mock.assert_awaited_once() @@ -269,6 +270,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(segments), 1) self.assertEqual(segments[0].transcript_text, "这是转写结果") self.assertEqual(segments[0].audio_url, "audio:12s") + self.assertIsNotNone(conversation.last_message_at) transcript_msgs = [ item["message"] diff --git a/app-android/app/build.gradle.kts b/app-android/app/build.gradle.kts index d496710..172a998 100644 --- a/app-android/app/build.gradle.kts +++ b/app-android/app/build.gradle.kts @@ -72,7 +72,7 @@ android { // 默认调试即连公网(https://lifecho.worldsplats.com),便于公网调试 debug { buildConfigField("Boolean", "IS_DEBUG_MODE", "true") - buildConfigField("Boolean", "USE_PROD_SERVER", "false") + buildConfigField("Boolean", "USE_PROD_SERVER", "true") applicationIdSuffix = ".debug" versionNameSuffix = "-debug" } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/data/database/ConversationDao.kt b/app-android/app/src/main/java/com/huaga/life_echo/data/database/ConversationDao.kt index b7bfdf7..d826e1d 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/data/database/ConversationDao.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/data/database/ConversationDao.kt @@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.Flow @Dao interface ConversationDao { - @Query("SELECT * FROM conversations ORDER BY startedAt DESC") + @Query("SELECT * FROM conversations ORDER BY COALESCE(latestMessageTime, startedAt) DESC") fun getAllConversations(): Flow> @Query("SELECT * FROM conversations WHERE id = :id") diff --git a/app-android/app/src/main/java/com/huaga/life_echo/data/database/VoiceAttachmentDao.kt b/app-android/app/src/main/java/com/huaga/life_echo/data/database/VoiceAttachmentDao.kt index 81e33ec..7d6b2e9 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/data/database/VoiceAttachmentDao.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/data/database/VoiceAttachmentDao.kt @@ -16,6 +16,9 @@ interface VoiceAttachmentDao { @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insert(attachment: VoiceAttachment) + @Query("DELETE FROM voice_attachments WHERE conversationId = :conversationId AND userMessageIndex = :index") + suspend fun deleteByConversationAndIndex(conversationId: String, index: Int) + @Query("DELETE FROM voice_attachments WHERE conversationId = :conversationId") suspend fun deleteByConversationId(conversationId: String) } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/data/repository/ConversationRepository.kt b/app-android/app/src/main/java/com/huaga/life_echo/data/repository/ConversationRepository.kt index f59f005..4e29a44 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/data/repository/ConversationRepository.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/data/repository/ConversationRepository.kt @@ -31,6 +31,20 @@ class ConversationRepository( suspend fun updateConversation(conversation: Conversation) { conversationDao.updateConversation(conversation) } + + suspend fun touchLatestMessage( + conversationId: String, + latestMessagePreview: String, + latestMessageTime: Long, + ) { + val existing = conversationDao.getConversationById(conversationId) ?: return + conversationDao.updateConversation( + existing.copy( + latestMessagePreview = latestMessagePreview, + latestMessageTime = latestMessageTime, + ) + ) + } suspend fun deleteConversation(id: String): Result { val result = conversationApi.deleteConversation(id) @@ -112,19 +126,32 @@ class ConversationRepository( onSuccess = { conversations -> // 将DTO转换为Entity并保存到数据库 conversations.forEach { dto -> + val existing = conversationDao.getConversationById(dto.id) + val keepLocalLatest = + (existing?.latestMessageTime ?: Long.MIN_VALUE) > dto.latestMessageTime + val resolvedLatestPreview = if (keepLocalLatest) { + existing?.latestMessagePreview ?: dto.latestMessagePreview + } else { + dto.latestMessagePreview + } + val resolvedLatestTime = if (keepLocalLatest) { + existing?.latestMessageTime ?: dto.latestMessageTime + } else { + dto.latestMessageTime + } val conversation = Conversation( id = dto.id, - userId = "", // 需要从TokenManager获取 - startedAt = dto.latestMessageTime, - endedAt = null, - durationSeconds = 0, - summary = dto.latestMessagePreview, - currentTopic = null, - conversationStage = null, + userId = existing?.userId ?: "", // 需要从TokenManager获取 + startedAt = existing?.startedAt ?: dto.latestMessageTime, + endedAt = existing?.endedAt, + durationSeconds = existing?.durationSeconds ?: 0, + summary = resolvedLatestPreview, + currentTopic = existing?.currentTopic, + conversationStage = existing?.conversationStage, avatarUrl = dto.avatarUrl, title = dto.title, - latestMessagePreview = dto.latestMessagePreview, - latestMessageTime = dto.latestMessageTime + latestMessagePreview = resolvedLatestPreview, + latestMessageTime = resolvedLatestTime ) conversationDao.insertConversation(conversation) } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/data/repository/MessageRepository.kt b/app-android/app/src/main/java/com/huaga/life_echo/data/repository/MessageRepository.kt index 5291adc..7092a0c 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/data/repository/MessageRepository.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/data/repository/MessageRepository.kt @@ -44,4 +44,8 @@ class MessageRepository( suspend fun insertMessages(messages: List) { messageDao.insertMessages(messages) } + + suspend fun deleteMessageById(id: String) { + messageDao.getMessageById(id)?.let { messageDao.deleteMessage(it) } + } } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/data/repository/VoiceAttachmentRepository.kt b/app-android/app/src/main/java/com/huaga/life_echo/data/repository/VoiceAttachmentRepository.kt index 6c1fe48..75a56c8 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/data/repository/VoiceAttachmentRepository.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/data/repository/VoiceAttachmentRepository.kt @@ -18,6 +18,10 @@ class VoiceAttachmentRepository( voiceAttachmentDao.insert(attachment) } + suspend fun deleteByConversationAndIndex(conversationId: String, userMessageIndex: Int) { + voiceAttachmentDao.deleteByConversationAndIndex(conversationId, userMessageIndex) + } + suspend fun deleteByConversationId(conversationId: String) { voiceAttachmentDao.deleteByConversationId(conversationId) } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/RecordingDurationEstimator.kt b/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/RecordingDurationEstimator.kt new file mode 100644 index 0000000..920e265 --- /dev/null +++ b/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/RecordingDurationEstimator.kt @@ -0,0 +1,30 @@ +package com.huaga.life_echo.feature.voice + +import kotlin.math.ceil + +internal object RecordingDurationEstimator { + + fun elapsedSeconds(startElapsedMs: Long, nowElapsedMs: Long): Int { + if (startElapsedMs <= 0L || nowElapsedMs <= startElapsedMs) { + return 0 + } + return ((nowElapsedMs - startElapsedMs) / 1000L).toInt() + } + + fun metadataSeconds(durationMs: Long): Int { + if (durationMs <= 0L) { + return 0 + } + return ceil(durationMs / 1000.0).toInt() + } + + fun finalDurationSeconds( + startElapsedMs: Long, + stopElapsedMs: Long, + metadataDurationMs: Long?, + ): Int { + val elapsed = elapsedSeconds(startElapsedMs = startElapsedMs, nowElapsedMs = stopElapsedMs) + val metadata = metadataDurationMs?.let(::metadataSeconds) ?: 0 + return maxOf(elapsed, metadata) + } +} diff --git a/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/VoiceRecorder.kt b/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/VoiceRecorder.kt index c494eaa..382b0a7 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/VoiceRecorder.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/feature/voice/VoiceRecorder.kt @@ -4,6 +4,7 @@ import android.content.Context import android.media.MediaMetadataRetriever import android.media.MediaRecorder import android.os.Build +import android.os.SystemClock import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -22,6 +23,7 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { private val recorderLock = Any() private var mediaRecorder: MediaRecorder? = null private var durationJob: Job? = null + private var recordingStartedAtElapsedMs: Long = 0L private val scope = CoroutineScope(Dispatchers.Main) private val _isRecording = MutableStateFlow(false) @@ -68,6 +70,7 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { _isRecording.value = true _recordingDuration.value = 0 + recordingStartedAtElapsedMs = SystemClock.elapsedRealtime() startDurationTimer() @@ -95,13 +98,17 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { ?: return RecorderStopResult.Failure(IllegalStateException("当前没有录音实例")) stopDurationTimer() + val stopElapsedMs = SystemClock.elapsedRealtime() try { recorder.apply { stop() } RecorderStopResult.Success( - durationSeconds = getAudioDuration(filePath), + durationSeconds = resolveFinalDurationSeconds( + filePath = filePath, + stopElapsedMs = stopElapsedMs, + ), ) } catch (e: Exception) { e.printStackTrace() @@ -111,6 +118,7 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { mediaRecorder = null _isRecording.value = false _recordingDuration.value = 0 + recordingStartedAtElapsedMs = 0L } } } @@ -135,6 +143,7 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { mediaRecorder = null _isRecording.value = false _recordingDuration.value = 0 + recordingStartedAtElapsedMs = 0L } } } @@ -142,15 +151,22 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { /** * 获取音频文件时长(秒) */ - private fun getAudioDuration(filePath: String): Int { + private fun resolveFinalDurationSeconds(filePath: String, stopElapsedMs: Long): Int { + return RecordingDurationEstimator.finalDurationSeconds( + startElapsedMs = recordingStartedAtElapsedMs, + stopElapsedMs = stopElapsedMs, + metadataDurationMs = getAudioDurationMs(filePath), + ) + } + + private fun getAudioDurationMs(filePath: String): Long? { val retriever = MediaMetadataRetriever() return try { retriever.setDataSource(filePath) - val durationMs = retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_DURATION)?.toLongOrNull() ?: 0 - (durationMs / 1000).toInt() + retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_DURATION)?.toLongOrNull() } catch (e: Exception) { e.printStackTrace() - _recordingDuration.value + null } finally { retriever.release() } @@ -164,6 +180,7 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { mediaRecorder = null _isRecording.value = false _recordingDuration.value = 0 + recordingStartedAtElapsedMs = 0L } } @@ -171,8 +188,11 @@ class VoiceRecorder(private val context: Context) : RecorderEngine { stopDurationTimer() durationJob = scope.launch { while (isActive && _isRecording.value) { - delay(1000) - val newDuration = _recordingDuration.value + 1 + delay(200) + val newDuration = RecordingDurationEstimator.elapsedSeconds( + startElapsedMs = recordingStartedAtElapsedMs, + nowElapsedMs = SystemClock.elapsedRealtime(), + ) _recordingDuration.value = newDuration if (newDuration >= recordingLimit) { onRecordingLimitReached?.invoke() diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/components/chat/MessageList.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/components/chat/MessageList.kt index e6415dd..6dcad75 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/components/chat/MessageList.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/components/chat/MessageList.kt @@ -33,6 +33,10 @@ import com.huaga.life_echo.utils.TimeUtils import kotlinx.coroutines.delay import kotlinx.coroutines.launch +internal fun shouldRenderAudioMessage(filePath: String, durationSeconds: Int): Boolean { + return filePath.isNotBlank() && durationSeconds > 0 +} + /** * 消息列表组件 * @@ -160,14 +164,18 @@ fun MessageList( playbackInfo.state == PlaybackState.PLAYING val progress = if (playbackInfo.currentMessageId == message.id) playbackInfo.progress else 0f - - UserAudioMessageBubble( - messageId = message.id, - duration = duration, - isPlaying = isPlaying, - playbackProgress = progress, - onPlayClick = { onAudioPlayClick(message.id, filePath) } - ) + + if (shouldRenderAudioMessage(filePath = filePath, durationSeconds = duration)) { + UserAudioMessageBubble( + messageId = message.id, + duration = duration, + isPlaying = isPlaying, + playbackProgress = progress, + onPlayClick = { onAudioPlayClick(message.id, filePath) } + ) + } else { + UserMessageBubble(text = message.content) + } } else { UserMessageBubble(text = message.content) } @@ -183,14 +191,18 @@ fun MessageList( playbackInfo.state == PlaybackState.PLAYING val progress = if (playbackInfo.currentMessageId == message.id) playbackInfo.progress else 0f - - AIAudioMessageBubble( - messageId = message.id, - duration = duration, - isPlaying = isPlaying, - playbackProgress = progress, - onPlayClick = { onAudioPlayClick(message.id, filePath) } - ) + + if (shouldRenderAudioMessage(filePath = filePath, durationSeconds = duration)) { + AIAudioMessageBubble( + messageId = message.id, + duration = duration, + isPlaying = isPlaying, + playbackProgress = progress, + onPlayClick = { onAudioPlayClick(message.id, filePath) } + ) + } else { + AIMessageBubble(text = message.content) + } } } else { // 文本消息 - 在 [SPLIT] 处分割消息,显示为多个气泡 diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/CreateMemoryScreen.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/CreateMemoryScreen.kt index 56f3991..c7be1df 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/CreateMemoryScreen.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/CreateMemoryScreen.kt @@ -49,7 +49,6 @@ fun CreateMemoryScreen( val isRecording by viewModel.isRecording.collectAsState() val transcript by viewModel.transcript.collectAsState() val connectionStatus by viewModel.connectionStatus.collectAsState() - val userMessages by viewModel.userMessages.collectAsState() val historyMessages by viewModel.historyMessages.collectAsState() val isStreaming by viewModel.isStreaming.collectAsState() val streamingText by viewModel.streamingText.collectAsState() @@ -125,27 +124,7 @@ fun CreateMemoryScreen( // 构建消息列表(包含历史消息和当前消息) // 注意:AI回复已经直接添加到 historyMessages 中,不需要额外处理 agentResponse - val messages = remember(historyMessages, userMessages) { - buildList { - // 先添加历史消息 - addAll(historyMessages) - - // 添加当前会话的用户消息(排除已存在的) - val existingUserMessageIds = historyMessages.filter { it.senderType == "user" }.map { it.content }.toSet() - userMessages.forEachIndexed { index, text -> - if (!existingUserMessageIds.contains(text)) { - add(MessageDto( - id = "user_${historyMessages.size + index}", - conversationId = conversationId, - content = text, - senderType = "user", - timestamp = System.currentTimeMillis() - (userMessages.size - index) * 1000L, - messageType = "text" - )) - } - } - }.sortedBy { it.timestamp } // 按时间排序 - } + val messages = remember(historyMessages) { historyMessages.sortedBy { it.timestamp } } Scaffold( snackbarHost = { SnackbarHost(snackbarHostState) }, @@ -233,9 +212,7 @@ fun CreateMemoryScreen( viewModel.startRecordingVoice() } }, - onStopRecording = { - viewModel.stopAndSendRecording() - }, + onStopRecording = { viewModel.stopAndSendRecording() }, onCancelRecording = { viewModel.cancelRecordingVoice() }, diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/ConversationListViewModel.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/ConversationListViewModel.kt index 5805237..df117a9 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/ConversationListViewModel.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/ConversationListViewModel.kt @@ -8,6 +8,7 @@ import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.launch @@ -17,6 +18,9 @@ class ConversationListViewModel( // 本地数据库的对话列表 val conversations = conversationRepository.getAllConversations() + .map { items -> + items.sortedByDescending { it.latestMessageTime ?: it.startedAt } + } .stateIn( scope = viewModelScope, started = SharingStarted.WhileSubscribed(5000), diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt index 29d5b2c..6ef3277 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModel.kt @@ -23,7 +23,10 @@ import com.huaga.life_echo.network.WebSocketMessage import com.huaga.life_echo.network.MessageType import com.huaga.life_echo.model.MessageDto import com.huaga.life_echo.data.database.Chapter +import com.huaga.life_echo.data.database.Message import com.huaga.life_echo.data.database.VoiceAttachment +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -32,10 +35,12 @@ import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch import kotlinx.coroutines.delay import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.withContext import java.io.File class CreateMemoryViewModel( @@ -48,6 +53,7 @@ class CreateMemoryViewModel( private val conversationRealtime: ConversationRealtimePort, private val tokenInitializer: (Context) -> Unit = TokenManager::initialize, private val recordingCoordinator: RecordingCoordinator, + private val ioDispatcher: CoroutineDispatcher = Dispatchers.IO, ) : ViewModel() { companion object { @@ -82,7 +88,6 @@ class CreateMemoryViewModel( val agentResponse = MutableStateFlow("") val connectionStatus = MutableStateFlow("未连接") val conversationId = MutableStateFlow(null) - val userMessages = MutableStateFlow>(emptyList()) val historyMessages = MutableStateFlow>(emptyList()) @@ -115,11 +120,85 @@ class CreateMemoryViewModel( @Volatile private var waitingForTranscribeOnly = false private var taskPollingJob: Job? = null + private var localMessagesJob: Job? = null init { tokenInitializer(context) } + private fun bindLocalConversationMessages(convId: String) { + localMessagesJob?.cancel() + historyMessages.value = emptyList() + _audioFilePaths.value = emptyMap() + _audioDurations.value = emptyMap() + + localMessagesJob = viewModelScope.launch { + messageRepository.getMessagesByConversationId(convId).collectLatest { messages -> + val mappedMessages = messages.map { message -> + MessageDto( + id = message.id, + conversationId = message.conversationId, + content = message.content, + senderType = message.senderType, + timestamp = message.timestamp, + messageType = message.messageType, + ) + } + historyMessages.value = mappedMessages + refreshLocalVoiceAttachments(convId, mappedMessages) + } + } + } + + private suspend fun refreshLocalVoiceAttachments(convId: String, messages: List) { + val attachments = withContext(ioDispatcher) { + voiceAttachmentRepository.getByConversationId(convId) + } + if (attachments.isEmpty()) { + _audioFilePaths.value = emptyMap() + _audioDurations.value = emptyMap() + return + } + + val attachmentMap = attachments.associateBy { it.userMessageIndex } + var userMsgIndex = 0 + val newPaths = mutableMapOf() + val newDurations = mutableMapOf() + + for (msg in messages) { + if (msg.senderType != "user") continue + if (msg.messageType != "audio") { + userMsgIndex++ + continue + } + val att = attachmentMap[userMsgIndex] + if (att == null) { + userMsgIndex++ + continue + } + if (File(att.filePath).exists()) { + newPaths[msg.id] = att.filePath + newDurations[msg.id] = att.durationSeconds + } + userMsgIndex++ + } + + _audioFilePaths.value = newPaths + _audioDurations.value = newDurations + } + + private suspend fun updateConversationLatestMessage( + conversationId: String, + messagePreview: String, + timestamp: Long, + ) { + conversationRepository.touchLatestMessage( + conversationId = conversationId, + latestMessagePreview = messagePreview, + latestMessageTime = timestamp, + ) + } + /** * 确保 WebSocket 已连接。如果未连接则尝试重连,并等待 state 变为 Connected(带超时)。 * @return true 表示已连接,false 表示连接失败 @@ -154,6 +233,7 @@ class CreateMemoryViewModel( viewModelScope.launch { conversationId.value = convId connectionStatus.value = "连接中..." + bindLocalConversationMessages(convId) try { loadHistoryMessages(convId) @@ -184,55 +264,14 @@ class CreateMemoryViewModel( private suspend fun loadHistoryMessages(convId: String) { val result = conversationApi.getMessages(convId) result.fold( - onSuccess = { messages -> - historyMessages.value = messages + onSuccess = { messageRepository.syncMessages(convId) - _audioFilePaths.value = emptyMap() - _audioDurations.value = emptyMap() - mergeLocalVoiceAttachments(convId, messages) }, onFailure = { exception -> connectionStatus.value = "加载历史消息失败: ${exception.message}" } ) } - - /** - * 将本地持久化的语音附件合并到 audioFilePaths 和 audioDurations, - * 使重新进入时语音条显示与首次一致(含时长、可播放) - */ - private suspend fun mergeLocalVoiceAttachments(convId: String, messages: List) { - val attachments = voiceAttachmentRepository.getByConversationId(convId) - if (attachments.isEmpty()) return - - val attachmentMap = attachments.associateBy { it.userMessageIndex } - var userMsgIndex = 0 - val newPaths = mutableMapOf() - val newDurations = mutableMapOf() - - for (msg in messages) { - if (msg.senderType != "user") continue - if (msg.messageType != "audio") { - userMsgIndex++ - continue - } - val att = attachmentMap[userMsgIndex] - if (att == null) { - userMsgIndex++ - continue - } - if (File(att.filePath).exists()) { - newPaths[msg.id] = att.filePath - newDurations[msg.id] = att.durationSeconds - } - userMsgIndex++ - } - - if (newPaths.isNotEmpty() || newDurations.isNotEmpty()) { - _audioFilePaths.value = _audioFilePaths.value + newPaths - _audioDurations.value = _audioDurations.value + newDurations - } - } fun startConversation() { viewModelScope.launch { @@ -253,14 +292,11 @@ class CreateMemoryViewModel( isRecording.value = true connectionStatus.value = "连接中..." - historyMessages.value = emptyList() - userMessages.value = emptyList() streamingText.value = "" agentResponse.value = "" isStreaming.value = false isTyping.value = false - _audioFilePaths.value = emptyMap() - _audioDurations.value = emptyMap() + bindLocalConversationMessages(convId) conversationApi.clearTasks() @@ -381,7 +417,9 @@ class CreateMemoryViewModel( return@launch } val audioBytes = try { - File(capture.session.filePath).readBytes() + withContext(ioDispatcher) { + File(capture.session.filePath).readBytes() + } } catch (e: Exception) { Log.e(TAG, "读取录音文件失败: ${e.message}") return@launch @@ -400,7 +438,8 @@ class CreateMemoryViewModel( val text = withTimeoutOrNull(15000L) { pendingTranscribeChannel.receive() } waitingForTranscribeOnly = false if (!text.isNullOrBlank() && !text.startsWith("转写失败")) { - sendTextMessage(text) + transcript.value = text + connectionStatus.value = "转写完成" } else { connectionStatus.value = "转写失败或超时,请重试" } @@ -436,38 +475,46 @@ class CreateMemoryViewModel( val id = conversationId.value ?: return val userMessageIndex = historyMessages.value.count { it.senderType == "user" } - val persistentPath = persistVoiceFile(id, userMessageIndex, filePath) - if (persistentPath != null) { - voiceAttachmentRepository.insert( - VoiceAttachment( - conversationId = id, - userMessageIndex = userMessageIndex, - filePath = persistentPath, - durationSeconds = durationSeconds, + val persistentPath = withContext(ioDispatcher) { + val path = persistVoiceFile(id, userMessageIndex, filePath) + if (path != null) { + voiceAttachmentRepository.insert( + VoiceAttachment( + conversationId = id, + userMessageIndex = userMessageIndex, + filePath = path, + durationSeconds = durationSeconds, + ) ) - ) + } + path } val tempMessageId = "audio_user_${System.currentTimeMillis()}" - val tempMessage = MessageDto( + val messageTimestamp = System.currentTimeMillis() + val tempMessage = Message( id = tempMessageId, conversationId = id, content = "[语音消息]", senderType = "user", - timestamp = System.currentTimeMillis(), + timestamp = messageTimestamp, messageType = "audio" ) - historyMessages.value = historyMessages.value + tempMessage - val displayPath = persistentPath ?: filePath - _audioFilePaths.value = _audioFilePaths.value + (tempMessageId to displayPath) - _audioDurations.value = _audioDurations.value + (tempMessageId to durationSeconds) + messageRepository.insertMessage(tempMessage) + updateConversationLatestMessage( + conversationId = id, + messagePreview = tempMessage.content, + timestamp = messageTimestamp, + ) val segmentFiles = try { - AudioSegmenter.split( - inputPath = filePath, - segmentDurationSeconds = SEGMENT_DURATION_SECONDS, - cacheDir = context.cacheDir, - ) + withContext(ioDispatcher) { + AudioSegmenter.split( + inputPath = filePath, + segmentDurationSeconds = SEGMENT_DURATION_SECONDS, + cacheDir = context.cacheDir, + ) + } } catch (e: Exception) { Log.e(TAG, "音频切片失败: ${e.message}", e) connectionStatus.value = "音频处理失败: ${e.message}" @@ -475,11 +522,13 @@ class CreateMemoryViewModel( } try { - val segments = PendingVoiceSegmentBatchBuilder.build( - segmentFiles = segmentFiles, - conversationId = id, - voiceSessionId = voiceSessionId, - ) + val segments = withContext(ioDispatcher) { + PendingVoiceSegmentBatchBuilder.build( + segmentFiles = segmentFiles, + conversationId = id, + voiceSessionId = voiceSessionId, + ) + } isTyping.value = true Log.d(TAG, "并行发送 ${segments.size} 个音频段,服务端按 segmentIndex 排序拼接") @@ -507,11 +556,14 @@ class CreateMemoryViewModel( Log.e(TAG, "音频段发送失败: ${e.message}", e) connectionStatus.value = "发送失败: ${e.message}" errorMessages.value = (errorMessages.value + "发送失败: ${e.message}").takeLast(10) - historyMessages.value = historyMessages.value.filter { it.id != tempMessageId } - _audioFilePaths.value = _audioFilePaths.value - tempMessageId - _audioDurations.value = _audioDurations.value - tempMessageId + withContext(ioDispatcher) { + messageRepository.deleteMessageById(tempMessageId) + voiceAttachmentRepository.deleteByConversationAndIndex(id, userMessageIndex) + } } finally { - segmentFiles.forEach { it.file.delete() } + withContext(ioDispatcher) { + segmentFiles.forEach { it.file.delete() } + } } } @@ -548,19 +600,23 @@ class CreateMemoryViewModel( conversationId.value?.let { id -> Log.d(TAG, "发送消息到对话: $id") - userMessages.value = userMessages.value + text - - val tempMessage = MessageDto( + val messageTimestamp = System.currentTimeMillis() + val tempMessage = Message( id = "temp_user_${System.currentTimeMillis()}", conversationId = id, content = text, senderType = "user", - timestamp = System.currentTimeMillis(), + timestamp = messageTimestamp, messageType = "text" ) - historyMessages.value = historyMessages.value + tempMessage try { + messageRepository.insertMessage(tempMessage) + updateConversationLatestMessage( + conversationId = id, + messagePreview = text, + timestamp = messageTimestamp, + ) isTyping.value = true conversationRealtime.sendText(id, text) Log.d(TAG, "消息发送成功") @@ -569,7 +625,7 @@ class CreateMemoryViewModel( Log.e(TAG, "消息发送失败: ${e.message}", e) connectionStatus.value = "发送失败: ${e.message}" errorMessages.value = (errorMessages.value + "发送失败: ${e.message}").takeLast(10) - historyMessages.value = historyMessages.value.filter { it.id != tempMessage.id } + messageRepository.deleteMessageById(tempMessage.id) } } ?: run { Log.e(TAG, "对话ID为空,无法发送消息") @@ -603,15 +659,23 @@ class CreateMemoryViewModel( } conversationId.value?.let { id -> - val aiMessage = MessageDto( + val messageTimestamp = System.currentTimeMillis() + val aiMessage = Message( id = "ai_${System.currentTimeMillis()}_$index", conversationId = id, content = text, senderType = "assistant", - timestamp = System.currentTimeMillis(), + timestamp = messageTimestamp, messageType = "text" ) - historyMessages.value = historyMessages.value + aiMessage + viewModelScope.launch { + messageRepository.insertMessage(aiMessage) + updateConversationLatestMessage( + conversationId = id, + messagePreview = text, + timestamp = messageTimestamp, + ) + } } if (index == 0) { @@ -644,15 +708,23 @@ class CreateMemoryViewModel( agentResponse.value = streamingText.value conversationId.value?.let { id -> - val aiMessage = MessageDto( + val messageTimestamp = System.currentTimeMillis() + val aiMessage = Message( id = "ai_${System.currentTimeMillis()}", conversationId = id, content = streamingText.value, senderType = "assistant", - timestamp = System.currentTimeMillis(), + timestamp = messageTimestamp, messageType = "text" ) - historyMessages.value = historyMessages.value + aiMessage + viewModelScope.launch { + messageRepository.insertMessage(aiMessage) + updateConversationLatestMessage( + conversationId = id, + messagePreview = streamingText.value, + timestamp = messageTimestamp, + ) + } } isStreaming.value = false @@ -792,6 +864,7 @@ class CreateMemoryViewModel( override fun onCleared() { super.onCleared() taskPollingJob?.cancel() + localMessagesJob?.cancel() conversationRealtime.close() recordingCoordinator.release() audioPlayer.release() diff --git a/app-android/app/src/test/java/com/huaga/life_echo/data/repository/ConversationRepositoryTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/data/repository/ConversationRepositoryTest.kt index 3f12ebf..5b248ae 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/data/repository/ConversationRepositoryTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/data/repository/ConversationRepositoryTest.kt @@ -71,10 +71,84 @@ class ConversationRepositoryTest { assertEquals(listOf("conversation-1"), dao.deletedConversations.map { it.id }) } + @Test + fun touch_latest_message_updates_preview_and_timestamp() = runTest { + val existing = Conversation( + id = "conversation-1", + userId = "user-1", + startedAt = 1L, + endedAt = null, + durationSeconds = 0, + summary = null, + currentTopic = null, + conversationStage = null, + latestMessagePreview = null, + latestMessageTime = 10L, + ) + val dao = FakeConversationDao(existingConversation = existing) + val repository = ConversationRepository( + conversationDao = dao, + segmentDao = FakeConversationSegmentDao(), + conversationApi = FakeConversationApiPort(), + ) + + repository.touchLatestMessage( + conversationId = "conversation-1", + latestMessagePreview = "刚刚的新消息", + latestMessageTime = 999L, + ) + + assertEquals(1, dao.updatedConversations.size) + assertEquals("刚刚的新消息", dao.updatedConversations.single().latestMessagePreview) + assertEquals(999L, dao.updatedConversations.single().latestMessageTime) + } + + @Test + fun sync_conversations_does_not_overwrite_newer_local_latest_message() = runTest { + val existing = Conversation( + id = "conversation-1", + userId = "user-1", + startedAt = 1L, + endedAt = null, + durationSeconds = 0, + summary = null, + currentTopic = null, + conversationStage = null, + latestMessagePreview = "本地较新的消息", + latestMessageTime = 999L, + ) + val dao = FakeConversationDao(existingConversation = existing) + val api = FakeConversationApiPort( + conversationListResult = Result.success( + listOf( + ConversationListItemDto( + id = "conversation-1", + title = "Title", + avatarUrl = null, + latestMessagePreview = "服务端较旧消息", + latestMessageTime = 123L, + ) + ) + ) + ) + val repository = ConversationRepository( + conversationDao = dao, + segmentDao = FakeConversationSegmentDao(), + conversationApi = api, + ) + + repository.syncConversations() + + assertEquals(1, dao.insertedConversations.size) + assertEquals("本地较新的消息", dao.insertedConversations.single().latestMessagePreview) + assertEquals(999L, dao.insertedConversations.single().latestMessageTime) + } + private class FakeConversationDao( private var existingConversation: Conversation? = null, ) : ConversationDao { val insertedConversations = mutableListOf() + val updatedConversations = mutableListOf() val deletedConversations = mutableListOf() override fun getAllConversations() = flowOf(emptyList()) @@ -85,7 +159,10 @@ class ConversationRepositoryTest { insertedConversations += conversation existingConversation = conversation } - override suspend fun updateConversation(conversation: Conversation) = Unit + override suspend fun updateConversation(conversation: Conversation) { + updatedConversations += conversation + existingConversation = conversation + } override suspend fun deleteConversation(conversation: Conversation) { deletedConversations += conversation if (existingConversation?.id == conversation.id) { diff --git a/app-android/app/src/test/java/com/huaga/life_echo/feature/voice/RecordingDurationEstimatorTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/feature/voice/RecordingDurationEstimatorTest.kt new file mode 100644 index 0000000..982fc42 --- /dev/null +++ b/app-android/app/src/test/java/com/huaga/life_echo/feature/voice/RecordingDurationEstimatorTest.kt @@ -0,0 +1,41 @@ +package com.huaga.life_echo.feature.voice + +import org.junit.Assert.assertEquals +import org.junit.Test + +class RecordingDurationEstimatorTest { + + @Test + fun metadata_duration_rounds_up_to_next_second() { + assertEquals(2, RecordingDurationEstimator.metadataSeconds(durationMs = 1_500)) + assertEquals(1, RecordingDurationEstimator.metadataSeconds(durationMs = 1)) + assertEquals(0, RecordingDurationEstimator.metadataSeconds(durationMs = 0)) + } + + @Test + fun elapsed_duration_uses_wall_clock_instead_of_tick_counting() { + assertEquals(0, RecordingDurationEstimator.elapsedSeconds(startElapsedMs = 10_000, nowElapsedMs = 10_000)) + assertEquals(1, RecordingDurationEstimator.elapsedSeconds(startElapsedMs = 10_000, nowElapsedMs = 11_000)) + assertEquals(34, RecordingDurationEstimator.elapsedSeconds(startElapsedMs = 10_000, nowElapsedMs = 44_999)) + } + + @Test + fun final_duration_prefers_the_more_complete_measurement() { + assertEquals( + 35, + RecordingDurationEstimator.finalDurationSeconds( + startElapsedMs = 10_000, + stopElapsedMs = 44_999, + metadataDurationMs = 34_100, + ), + ) + assertEquals( + 35, + RecordingDurationEstimator.finalDurationSeconds( + startElapsedMs = 10_000, + stopElapsedMs = 44_100, + metadataDurationMs = 34_500, + ), + ) + } +} diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/components/chat/MessageListPresentationTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/components/chat/MessageListPresentationTest.kt new file mode 100644 index 0000000..a398ce5 --- /dev/null +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/components/chat/MessageListPresentationTest.kt @@ -0,0 +1,15 @@ +package com.huaga.life_echo.ui.components.chat + +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test + +class MessageListPresentationTest { + + @Test + fun audio_message_requires_file_and_positive_duration() { + assertTrue(shouldRenderAudioMessage(filePath = "/tmp/voice.m4a", durationSeconds = 12)) + assertFalse(shouldRenderAudioMessage(filePath = "", durationSeconds = 12)) + assertFalse(shouldRenderAudioMessage(filePath = "/tmp/voice.m4a", durationSeconds = 0)) + } +} diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt index 9531dce..35d4af8 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelRecordingCoordinatorTest.kt @@ -28,6 +28,7 @@ import com.huaga.life_echo.feature.voice.RecorderStartResult import com.huaga.life_echo.feature.voice.RecorderStopResult import com.huaga.life_echo.feature.voice.RecordingCoordinator import com.huaga.life_echo.network.WebSocketMessage +import com.huaga.life_echo.network.MessageType import com.huaga.life_echo.model.ChapterDto import com.huaga.life_echo.model.ConversationListItemDto import com.huaga.life_echo.model.CreateConversationResponse @@ -35,9 +36,14 @@ import com.huaga.life_echo.model.MessageDto import com.huaga.life_echo.model.TasksStatusDto import com.huaga.life_echo.testutil.MainDispatcherRule import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.launch import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals @@ -240,6 +246,47 @@ class CreateMemoryViewModelRecordingCoordinatorTest { } } + @Test + fun stop_and_send_recording_as_text_does_not_send_transcript_as_user_text_message() = + runTest(mainDispatcherRule.dispatcher.scheduler) { + val rootDir = Files.createTempDirectory("create-memory-vm").toFile() + val context = newContext(rootDir) + val recorder = FakeRecorderEngine() + recorder.stopResult = RecorderStopResult.Success(durationSeconds = 3) + val coordinator = newRecordingCoordinator( + rootDir = rootDir, + recorder = recorder, + recordingIds = ArrayDeque(listOf("recording-1")), + voiceSessionIds = ArrayDeque(listOf("voice-1")), + ) + val messageDao = FakeMessageDao() + val realtime = TranscriptConversationRealtimePort( + transcriptText = "这是转录文本", + scope = CoroutineScope(mainDispatcherRule.dispatcher), + ) + val viewModel = newViewModel( + context = context, + recordingCoordinator = coordinator, + realtime = realtime, + messageDao = messageDao, + ) + + try { + viewModel.conversationId.value = "conversation-1" + viewModel.startRecordingVoice() + + viewModel.stopAndSendRecordingAsText() + advanceUntilIdle() + + assertEquals("这是转录文本", viewModel.transcript.value) + assertTrue(messageDao.snapshot("conversation-1").isEmpty()) + assertTrue(realtime.sentTexts.isEmpty()) + } finally { + coordinator.cancel() + rootDir.deleteRecursively() + } + } + @Ignore("Requires pending segment pipeline not yet integrated") @Test fun stop_and_send_recording_requests_conversation_creation_when_id_missing() = @@ -366,6 +413,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest { recordingCoordinator: RecordingCoordinator, conversationApi: ConversationApiPort = NoOpConversationApiPort(), realtime: ConversationRealtimePort = NoOpConversationRealtimePort(), + messageDao: FakeMessageDao = FakeMessageDao(), pendingVoiceSegmentStore: PendingVoiceSegmentStore = PendingVoiceSegmentStore( File(context.filesDir, "pending-voice-segments") ), @@ -396,7 +444,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest { chapterDao = FakeChapterDao(), ), messageRepository = MessageRepository( - messageDao = FakeMessageDao(), + messageDao = messageDao, conversationApi = conversationApi, ), voiceAttachmentRepository = VoiceAttachmentRepository(voiceAttachmentDao = FakeVoiceAttachmentDao()), @@ -405,6 +453,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest { conversationRealtime = realtime, tokenInitializer = {}, recordingCoordinator = recordingCoordinator, + ioDispatcher = mainDispatcherRule.dispatcher, ) } @@ -515,20 +564,38 @@ class CreateMemoryViewModelRecordingCoordinatorTest { } private class FakeMessageDao : MessageDao { + private val messages = MutableStateFlow>(emptyList()) + override fun getMessagesByConversationId(conversationId: String) = - flowOf(emptyList()) - override suspend fun getMessageById(id: String): Message? = null - override suspend fun insertMessage(message: Message) = Unit - override suspend fun insertMessages(messages: List) = Unit + messages.map { all -> + all.filter { it.conversationId == conversationId }.sortedBy { it.timestamp } + } + override suspend fun getMessageById(id: String): Message? = messages.value.firstOrNull { it.id == id } + override suspend fun insertMessage(message: Message) { + messages.value = messages.value.filterNot { it.id == message.id } + message + } + override suspend fun insertMessages(messages: List) { + val ids = messages.map { it.id }.toSet() + this.messages.value = this.messages.value.filterNot { it.id in ids } + messages + } override suspend fun updateMessage(message: Message) = Unit - override suspend fun deleteMessage(message: Message) = Unit - override suspend fun deleteMessagesByConversationId(conversationId: String) = Unit + override suspend fun deleteMessage(message: Message) { + messages.value = messages.value.filterNot { it.id == message.id } + } + override suspend fun deleteMessagesByConversationId(conversationId: String) { + messages.value = messages.value.filterNot { it.conversationId == conversationId } + } + + fun snapshot(conversationId: String): List { + return messages.value.filter { it.conversationId == conversationId }.sortedBy { it.timestamp } + } } private class FakeVoiceAttachmentDao : VoiceAttachmentDao { override suspend fun getByConversationId(conversationId: String): List = emptyList() override suspend fun getByConversationAndIndex(conversationId: String, index: Int) = null override suspend fun insert(attachment: VoiceAttachment) = Unit + override suspend fun deleteByConversationAndIndex(conversationId: String, index: Int) = Unit override suspend fun deleteByConversationId(conversationId: String) = Unit } @@ -609,4 +676,53 @@ class CreateMemoryViewModelRecordingCoordinatorTest { override fun setGenerating(generating: Boolean) = Unit override fun close() = Unit } + + private class TranscriptConversationRealtimePort( + private val transcriptText: String, + private val scope: CoroutineScope, + ) : ConversationRealtimePort { + private var onMessageCallback: ((WebSocketMessage) -> Unit)? = null + private val _state = + MutableStateFlow(ConversationRealtimePort.State.NotConnected) + val sentTexts = mutableListOf() + + override val state: StateFlow = _state + + override suspend fun prepare() = Unit + + override suspend fun connect( + conversationId: String, + token: String?, + onMessage: (WebSocketMessage) -> Unit, + onError: ((String) -> Unit)?, + ) { + onMessageCallback = onMessage + _state.value = ConversationRealtimePort.State.Connected(conversationId) + } + + override suspend fun disconnect() = Unit + override fun isConnected(): Boolean = false + override suspend fun sendText(conversationId: String, text: String) { + sentTexts += text + } + override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit + override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit + override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit + override suspend fun sendTranscribeOnly(audioBytes: ByteArray, conversationId: String) { + scope.launch { + onMessageCallback?.invoke( + WebSocketMessage( + type = MessageType.transcript, + conversation_id = conversationId, + data = buildJsonObject { put("text", transcriptText) }, + ) + ) + } + } + override suspend fun sendEndConversation(conversationId: String) = Unit + override suspend fun cancelGeneration(conversationId: String) = Unit + override fun isGenerating(): Boolean = false + override fun setGenerating(generating: Boolean) = Unit + override fun close() = Unit + } } diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt index d546d6b..6f024e5 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/viewmodel/CreateMemoryViewModelWarmupTest.kt @@ -33,6 +33,7 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals @@ -150,9 +151,73 @@ class CreateMemoryViewModelWarmupTest { } } + @Test + fun initialize_conversation_restores_local_messages_when_remote_messages_are_unavailable() = + runTest(mainDispatcherRule.dispatcher.scheduler) { + val rootDir = Files.createTempDirectory("warmup-test").toFile() + val context = newContext(rootDir) + val realtime = FakeConversationRealtimePort() + val messageDao = FakeMessageDao( + initialMessages = listOf( + Message( + id = "local-user-1", + conversationId = "conversation-1", + content = "这是一条本地暂存消息", + senderType = "user", + timestamp = 1_000L, + messageType = "text", + ) + ) + ) + val viewModel = newViewModel( + context, + realtime = realtime, + messageDao = messageDao, + ) + + try { + viewModel.initializeConversation("conversation-1") + advanceUntilIdle() + + assertEquals(1, viewModel.historyMessages.value.size) + assertEquals( + "这是一条本地暂存消息", + viewModel.historyMessages.value.single().content, + ) + } finally { + rootDir.deleteRecursively() + } + } + + @Test + fun send_text_message_persists_local_message_before_remote_roundtrip() = + runTest(mainDispatcherRule.dispatcher.scheduler) { + val rootDir = Files.createTempDirectory("warmup-test").toFile() + val context = newContext(rootDir) + val messageDao = FakeMessageDao() + val realtime = FakeConversationRealtimePort(connected = true) + val viewModel = newViewModel( + context, + realtime = realtime, + messageDao = messageDao, + ) + + try { + viewModel.conversationId.value = "conversation-1" + viewModel.sendTextMessage("你好") + advanceUntilIdle() + + assertEquals(1, messageDao.snapshot("conversation-1").size) + assertEquals("你好", messageDao.snapshot("conversation-1").single().content) + } finally { + rootDir.deleteRecursively() + } + } + private fun newViewModel( context: Context, realtime: ConversationRealtimePort = FakeConversationRealtimePort(), + messageDao: FakeMessageDao = FakeMessageDao(), recordingCoordinator: RecordingCoordinator = RecordingCoordinator( recorder = FakeRecorderEngine(), captureFileFactory = { File(context.cacheDir, "rec_${System.nanoTime()}.m4a") }, @@ -169,7 +234,7 @@ class CreateMemoryViewModelWarmupTest { chapterDao = FakeChapterDao(), ), messageRepository = MessageRepository( - messageDao = FakeMessageDao(), + messageDao = messageDao, conversationApi = conversationApi, ), voiceAttachmentRepository = VoiceAttachmentRepository(voiceAttachmentDao = FakeVoiceAttachmentDao()), @@ -178,6 +243,7 @@ class CreateMemoryViewModelWarmupTest { conversationRealtime = realtime, tokenInitializer = {}, recordingCoordinator = recordingCoordinator, + ioDispatcher = mainDispatcherRule.dispatcher, ) } @@ -312,21 +378,41 @@ class CreateMemoryViewModelWarmupTest { override suspend fun deleteChapter(chapter: Chapter) = Unit } - private class FakeMessageDao : MessageDao { + private class FakeMessageDao( + initialMessages: List = emptyList(), + ) : MessageDao { + private val messages = MutableStateFlow(initialMessages) + override fun getMessagesByConversationId(conversationId: String) = - flowOf(emptyList()) + messages.map { all -> + all.filter { it.conversationId == conversationId }.sortedBy { it.timestamp } + } override suspend fun getMessageById(id: String): Message? = null - override suspend fun insertMessage(message: Message) = Unit - override suspend fun insertMessages(messages: List) = Unit + override suspend fun insertMessage(message: Message) { + messages.value = messages.value.filterNot { it.id == message.id } + message + } + override suspend fun insertMessages(messages: List) { + val ids = messages.map { it.id }.toSet() + this.messages.value = this.messages.value.filterNot { it.id in ids } + messages + } override suspend fun updateMessage(message: Message) = Unit - override suspend fun deleteMessage(message: Message) = Unit - override suspend fun deleteMessagesByConversationId(conversationId: String) = Unit + override suspend fun deleteMessage(message: Message) { + messages.value = messages.value.filterNot { it.id == message.id } + } + override suspend fun deleteMessagesByConversationId(conversationId: String) { + messages.value = messages.value.filterNot { it.conversationId == conversationId } + } + + fun snapshot(conversationId: String): List { + return messages.value.filter { it.conversationId == conversationId }.sortedBy { it.timestamp } + } } private class FakeVoiceAttachmentDao : VoiceAttachmentDao { override suspend fun getByConversationId(conversationId: String): List = emptyList() override suspend fun getByConversationAndIndex(conversationId: String, index: Int) = null override suspend fun insert(attachment: VoiceAttachment) = Unit + override suspend fun deleteByConversationAndIndex(conversationId: String, index: Int) = Unit override suspend fun deleteByConversationId(conversationId: String) = Unit } }