重构: 将 CreateMemoryViewModel 迁移到对话端口

ViewModel 现在依赖 ConversationApiPort 和 ConversationRealtimePort,而不是 ApiService/WebSocketClient。
ViewModelFactory 从 AppContainer 装配这些适配器。
This commit is contained in:
Kevin
2026-03-12 10:36:05 +08:00
parent 7e59c65602
commit 493dedda47
4 changed files with 331 additions and 243 deletions

View File

@@ -8,6 +8,9 @@ import com.huaga.life_echo.data.auth.TokenManager
import com.huaga.life_echo.data.repository.ConversationRepository
import com.huaga.life_echo.data.repository.ChapterRepository
import com.huaga.life_echo.data.repository.MessageRepository
import com.huaga.life_echo.feature.conversation.ports.AudioSegmentRequest
import com.huaga.life_echo.feature.conversation.ports.ConversationApiPort
import com.huaga.life_echo.feature.conversation.ports.ConversationRealtimePort
import com.huaga.life_echo.feature.voice.AudioPlayer
import com.huaga.life_echo.feature.voice.AudioSegmentException
import com.huaga.life_echo.feature.voice.AudioSegmentFile
@@ -23,11 +26,8 @@ import com.huaga.life_echo.feature.voice.PlaybackInfo
import com.huaga.life_echo.feature.voice.ConversationScopedTranscriptRouter
import com.huaga.life_echo.feature.voice.SegmentedVoiceTranscriptState
import com.huaga.life_echo.feature.voice.VoiceRecorder
import com.huaga.life_echo.network.WebSocketClient
import com.huaga.life_echo.network.WebSocketMessage
import com.huaga.life_echo.network.MessageType
import com.huaga.life_echo.network.ApiService
import com.huaga.life_echo.network.AuthService
import com.huaga.life_echo.network.models.MessageDto
import com.huaga.life_echo.data.database.Chapter
import kotlinx.coroutines.channels.Channel
@@ -52,6 +52,8 @@ class CreateMemoryViewModel(
private val chapterRepository: ChapterRepository,
private val messageRepository: MessageRepository,
private val context: Context,
private val conversationApi: ConversationApiPort,
private val conversationRealtime: ConversationRealtimePort,
private val recordingCoordinator: RecordingCoordinator = createDefaultRecordingCoordinator(context),
private val audioPlayer: AudioPlayer = AudioPlayer(context),
private val tokenInitializer: (Context) -> Unit = TokenManager::initialize,
@@ -80,8 +82,6 @@ class CreateMemoryViewModel(
}
}
private val webSocketClient = WebSocketClient()
private val apiService = ApiService(TokenManager, AuthService())
private val pendingVoiceSegmentStore = PendingVoiceSegmentStore(
File(context.filesDir, "pending-voice-segments")
)
@@ -96,45 +96,35 @@ class CreateMemoryViewModel(
val agentResponse = MutableStateFlow("")
val connectionStatus = MutableStateFlow("未连接")
val conversationId = MutableStateFlow<String?>(null)
val userMessages = MutableStateFlow<List<String>>(emptyList()) // 用户发送的文本消息列表
val userMessages = MutableStateFlow<List<String>>(emptyList())
// 历史消息
val historyMessages = MutableStateFlow<List<MessageDto>>(emptyList())
// 流式消息相关
val isStreaming = MutableStateFlow(false) // 是否正在流式接收
val streamingText = MutableStateFlow("") // 流式文本内容
val isTyping = MutableStateFlow(false) // AI是否正在输入
val isStreaming = MutableStateFlow(false)
val streamingText = MutableStateFlow("")
val isTyping = MutableStateFlow(false)
// WebSocket连接状态用于调试
val wsIsConnected = MutableStateFlow(false)
// 后台处理状态
val isProcessing = MutableStateFlow(false) // 是否正在处理回忆录
val processingStatus = MutableStateFlow("") // 处理状态文本
val isProcessing = MutableStateFlow(false)
val processingStatus = MutableStateFlow("")
// 调试信息
val lastMessageType = MutableStateFlow<String?>(null) // 最后收到的消息类型
val lastMessageTime = MutableStateFlow<String?>(null) // 最后消息时间
val errorMessages = MutableStateFlow<List<String>>(emptyList()) // 错误消息列表
val messageCount = MutableStateFlow(0) // 消息计数
val lastMessageType = MutableStateFlow<String?>(null)
val lastMessageTime = MutableStateFlow<String?>(null)
val errorMessages = MutableStateFlow<List<String>>(emptyList())
val messageCount = MutableStateFlow(0)
// 语音录制相关状态
val isVoiceRecording: StateFlow<Boolean> = recordingCoordinator.isRecording
val recordingDuration: StateFlow<Int> = recordingCoordinator.recordingDuration
// 音频播放相关状态
val playbackInfo: StateFlow<PlaybackInfo> = audioPlayer.playbackInfo
// 音频文件路径映射 (messageId -> filePath)
private val _audioFilePaths = MutableStateFlow<Map<String, String>>(emptyMap())
val audioFilePaths: StateFlow<Map<String, String>> = _audioFilePaths.asStateFlow()
// 音频时长映射 (messageId -> duration in seconds)
private val _audioDurations = MutableStateFlow<Map<String, Int>>(emptyMap())
val audioDurations: StateFlow<Map<String, Int>> = _audioDurations.asStateFlow()
// 仅转写结果:发 transcribe_only 后等待 transcript 用
private val pendingTranscribeChannel = Channel<String>(Channel.RENDEZVOUS)
@Volatile
private var waitingForTranscribeOnly = false
@@ -146,9 +136,6 @@ class CreateMemoryViewModel(
}
}
/**
* 初始化对话如果提供了conversationId
*/
fun initializeConversation(convId: String?) {
if (convId == null || convId == "new") {
return
@@ -160,13 +147,13 @@ class CreateMemoryViewModel(
segmentedVoiceTranscriptState = SegmentedVoiceTranscriptState()
try {
// 加载历史消息
conversationRealtime.prepare()
loadHistoryMessages(convId)
// 获取访问令牌并连接WebSocket
val token = TokenManager.getAccessToken()
Log.d(TAG, "初始化对话conversationId: $convId")
webSocketClient.connect(
conversationRealtime.connect(
convId,
token,
onMessage = { message ->
@@ -176,13 +163,9 @@ class CreateMemoryViewModel(
onError = { errorMsg ->
Log.e(TAG, "WebSocket错误: $errorMsg")
connectionStatus.value = "错误: $errorMsg"
// 添加到错误列表最多保留10条
errorMessages.value = (errorMessages.value + errorMsg).takeLast(10)
}
)
// 注意连接状态将在收到服务器的connect消息时更新
// 这里不立即设置为"已连接",等待服务器确认
} catch (e: Exception) {
Log.e(TAG, "初始化对话失败: ${e.message}", e)
connectionStatus.value = "连接失败: ${e.message}"
@@ -190,19 +173,14 @@ class CreateMemoryViewModel(
}
}
/**
* 加载历史消息
*/
private suspend fun loadHistoryMessages(convId: String) {
val result = apiService.getMessages(convId)
val result = conversationApi.getMessages(convId)
result.fold(
onSuccess = { messages ->
historyMessages.value = messages
// 同步到本地数据库
messageRepository.syncMessages(convId)
},
onFailure = { exception ->
// 如果加载失败,记录错误但不阻止连接
connectionStatus.value = "加载历史消息失败: ${exception.message}"
}
)
@@ -213,8 +191,7 @@ class CreateMemoryViewModel(
connectionStatus.value = "创建对话中..."
try {
// 先通过API创建对话
val createResult = apiService.createConversation()
val createResult = conversationApi.createConversation()
val convId = createResult.fold(
onSuccess = { response -> response.id },
onFailure = { exception ->
@@ -228,7 +205,6 @@ class CreateMemoryViewModel(
isRecording.value = true
connectionStatus.value = "连接中..."
// 新建对话:清空当前会话展示状态
historyMessages.value = emptyList()
userMessages.value = emptyList()
segmentedVoiceTranscriptState = SegmentedVoiceTranscriptState()
@@ -239,13 +215,11 @@ class CreateMemoryViewModel(
_audioFilePaths.value = emptyMap()
_audioDurations.value = emptyMap()
// 清除旧的任务记录
apiService.clearTasks()
conversationApi.clearTasks()
// 获取访问令牌并连接WebSocket
val token = TokenManager.getAccessToken()
Log.d(TAG, "开始对话conversationId: $convId")
webSocketClient.connect(
conversationRealtime.connect(
convId,
token,
onMessage = { message ->
@@ -256,12 +230,9 @@ class CreateMemoryViewModel(
Log.e(TAG, "WebSocket错误: $errorMsg")
connectionStatus.value = "错误: $errorMsg"
isRecording.value = false
// 添加到错误列表最多保留10条
errorMessages.value = (errorMessages.value + errorMsg).takeLast(10)
}
)
// 注意连接状态将在收到服务器的connect消息时更新
} catch (e: Exception) {
Log.e(TAG, "开始对话失败: ${e.message}", e)
connectionStatus.value = "连接失败: ${e.message}"
@@ -273,8 +244,8 @@ class CreateMemoryViewModel(
fun endConversation() {
viewModelScope.launch {
conversationId.value?.let { id ->
webSocketClient.sendEndConversation(id)
webSocketClient.disconnect()
conversationRealtime.sendEndConversation(id)
conversationRealtime.disconnect()
connectionStatus.value = "已断开"
wsIsConnected.value = false
isRecording.value = false
@@ -285,7 +256,7 @@ class CreateMemoryViewModel(
fun sendAudioChunk(chunk: ByteArray) {
viewModelScope.launch {
conversationId.value?.let { id ->
webSocketClient.sendAudioChunk(chunk, id)
conversationRealtime.sendAudioChunk(chunk, id)
}
}
}
@@ -344,13 +315,13 @@ class CreateMemoryViewModel(
return@withLock null
}
if (!webSocketClient.isConnected()) {
if (!conversationRealtime.isConnected()) {
Log.w(TAG, "分段发送前 WebSocket 未连接,开始等待重连")
connectionStatus.value = "未连接,正在重连..."
wsIsConnected.value = false
val token = TokenManager.getAccessToken()
try {
webSocketClient.connect(
conversationRealtime.connect(
readyConversationId,
token,
onMessage = { message -> handleWebSocketMessage(message) },
@@ -363,7 +334,7 @@ class CreateMemoryViewModel(
}
val connected = withTimeoutOrNull(SEGMENT_WAIT_TIMEOUT_MS) {
if (webSocketClient.isConnected()) {
if (conversationRealtime.isConnected()) {
true
} else {
wsIsConnected.first { it }
@@ -378,9 +349,6 @@ class CreateMemoryViewModel(
readyConversationId
}
/**
* 开始录音(需 API 26+ 与录音权限)
*/
fun startRecordingVoice() {
Log.d(TAG, "开始录音")
when (val result = recordingCoordinator.start()) {
@@ -570,12 +538,12 @@ class CreateMemoryViewModel(
startConversation()
delay(300)
}
if (!webSocketClient.isConnected()) {
if (!conversationRealtime.isConnected()) {
connectionStatus.value = "未连接,正在重连..."
conversationId.value?.let { id ->
val token = TokenManager.getAccessToken()
try {
webSocketClient.connect(
conversationRealtime.connect(
id,
token,
onMessage = { m -> handleWebSocketMessage(m) },
@@ -594,7 +562,7 @@ class CreateMemoryViewModel(
val id = conversationId.value ?: return@launch
try {
waitingForTranscribeOnly = true
webSocketClient.sendTranscribeOnly(audioBytes, id)
conversationRealtime.sendTranscribeOnly(audioBytes, id)
val text = withTimeoutOrNull(15000L) { pendingTranscribeChannel.receive() }
waitingForTranscribeOnly = false
if (!text.isNullOrBlank() && !text.startsWith("转写失败")) {
@@ -612,9 +580,6 @@ class CreateMemoryViewModel(
}
}
/**
* 发送分段语音(持久化待发队列 + 自动重试,不阻塞后续分段发送)
*/
private suspend fun dispatchPendingVoiceSegment(pendingSegment: PendingVoiceSegment) {
if (!tryAcquirePendingDispatch(pendingSegment.clientSegmentId)) {
return
@@ -646,14 +611,16 @@ class CreateMemoryViewModel(
var lastError: Exception? = null
for (attempt in 1..SEGMENT_RETRY_MAX_ATTEMPTS) {
try {
webSocketClient.sendAudioSegment(
audioBytes = segmentToSend.audioBytes,
conversationId = segmentToSend.conversationId,
voiceSessionId = segmentToSend.voiceSessionId,
segmentIndex = segmentToSend.segmentIndex,
duration = segmentToSend.durationSeconds,
isLast = segmentToSend.isLast,
clientSegmentId = segmentToSend.clientSegmentId,
conversationRealtime.sendAudioSegment(
AudioSegmentRequest(
audioBytes = segmentToSend.audioBytes,
conversationId = segmentToSend.conversationId,
voiceSessionId = segmentToSend.voiceSessionId,
segmentIndex = segmentToSend.segmentIndex,
duration = segmentToSend.durationSeconds,
isLast = segmentToSend.isLast,
clientSegmentId = segmentToSend.clientSegmentId,
)
)
pendingVoiceSegmentStore.remove(segmentToSend.clientSegmentId)
Log.d(TAG, "分段语音发送成功: idx=${segmentToSend.segmentIndex}, attempt=$attempt")
@@ -683,28 +650,22 @@ class CreateMemoryViewModel(
}
}
/**
* 发送音频消息
*/
private suspend fun sendAudioMessage(audioBytes: ByteArray, filePath: String, durationSeconds: Int) {
Log.d(TAG, "准备发送音频消息,大小: ${audioBytes.size}, 时长: ${durationSeconds}s")
// 确保已连接(手势松开后马上发语音,尽量缩短等待)
if (conversationId.value == null) {
Log.d(TAG, "对话ID为空开始创建新对话")
startConversation()
delay(300)
}
// 检查连接状态
if (!webSocketClient.isConnected()) {
if (!conversationRealtime.isConnected()) {
Log.w(TAG, "WebSocket未连接尝试重连")
connectionStatus.value = "未连接,正在重连..."
// 重连逻辑类似 sendTextMessage
conversationId.value?.let { id ->
val token = TokenManager.getAccessToken()
try {
webSocketClient.connect(
conversationRealtime.connect(
id,
token,
onMessage = { message -> handleWebSocketMessage(message) },
@@ -722,10 +683,8 @@ class CreateMemoryViewModel(
}
conversationId.value?.let { id ->
// 生成临时消息 ID
val tempMessageId = "audio_user_${System.currentTimeMillis()}"
// 添加到历史消息(本地先显示)
val tempMessage = MessageDto(
id = tempMessageId,
conversationId = id,
@@ -736,24 +695,18 @@ class CreateMemoryViewModel(
)
historyMessages.value = historyMessages.value + tempMessage
// 保存音频文件路径和时长
_audioFilePaths.value = _audioFilePaths.value + (tempMessageId to filePath)
_audioDurations.value = _audioDurations.value + (tempMessageId to durationSeconds)
try {
// 显示加载动画
isTyping.value = true
// 发送音频消息
webSocketClient.sendAudioMessage(audioBytes, id, durationSeconds)
conversationRealtime.sendAudioMessage(audioBytes, id, durationSeconds)
Log.d(TAG, "音频消息发送成功")
} catch (e: Exception) {
isTyping.value = false
Log.e(TAG, "音频消息发送失败: ${e.message}", e)
connectionStatus.value = "发送失败: ${e.message}"
errorMessages.value = (errorMessages.value + "发送失败: ${e.message}").takeLast(10)
// 移除临时消息
historyMessages.value = historyMessages.value.filter { it.id != tempMessageId }
_audioFilePaths.value = _audioFilePaths.value - tempMessageId
_audioDurations.value = _audioDurations.value - tempMessageId
@@ -763,17 +716,11 @@ class CreateMemoryViewModel(
// ==================== 音频播放功能 ====================
/**
* 播放/暂停音频
*/
fun toggleAudioPlayback(messageId: String, filePath: String) {
Log.d(TAG, "切换音频播放状态: $messageId, $filePath")
audioPlayer.play(messageId, filePath)
}
/**
* 停止音频播放
*/
fun stopAudioPlayback() {
audioPlayer.stop()
}
@@ -784,22 +731,19 @@ class CreateMemoryViewModel(
Log.d(TAG, "准备发送文本消息: $text")
viewModelScope.launch {
// 确保已连接
if (conversationId.value == null) {
Log.d(TAG, "对话ID为空开始创建新对话")
startConversation()
// 等待连接建立
delay(500)
}
// 检查连接状态
if (!webSocketClient.isConnected()) {
if (!conversationRealtime.isConnected()) {
Log.w(TAG, "WebSocket未连接尝试重连")
connectionStatus.value = "未连接,正在重连..."
conversationId.value?.let { id ->
val token = TokenManager.getAccessToken()
try {
webSocketClient.connect(
conversationRealtime.connect(
id,
token,
onMessage = { message ->
@@ -811,7 +755,7 @@ class CreateMemoryViewModel(
connectionStatus.value = "错误: $errorMsg"
}
)
delay(500) // 等待连接建立
delay(500)
} catch (e: Exception) {
Log.e(TAG, "重连失败: ${e.message}", e)
connectionStatus.value = "重连失败: ${e.message}"
@@ -826,10 +770,8 @@ class CreateMemoryViewModel(
conversationId.value?.let { id ->
Log.d(TAG, "发送消息到对话: $id")
// 添加到用户消息列表
userMessages.value = userMessages.value + text
// 添加到历史消息(临时,等待服务器确认)
val tempMessage = MessageDto(
id = "temp_user_${System.currentTimeMillis()}",
conversationId = id,
@@ -840,19 +782,15 @@ class CreateMemoryViewModel(
)
historyMessages.value = historyMessages.value + tempMessage
// 发送文本消息
try {
// 立即显示加载动画
isTyping.value = true
webSocketClient.sendTextMessage(text, id)
conversationRealtime.sendText(id, text)
Log.d(TAG, "消息发送成功")
} catch (e: Exception) {
// 发送失败时隐藏加载动画
isTyping.value = false
Log.e(TAG, "消息发送失败: ${e.message}", e)
connectionStatus.value = "发送失败: ${e.message}"
errorMessages.value = (errorMessages.value + "发送失败: ${e.message}").takeLast(10)
// 移除临时消息
historyMessages.value = historyMessages.value.filter { it.id != tempMessage.id }
}
} ?: run {
@@ -863,7 +801,6 @@ class CreateMemoryViewModel(
}
private fun handleWebSocketMessage(message: WebSocketMessage) {
// 更新调试信息
lastMessageType.value = message.type.name
lastMessageTime.value = java.text.SimpleDateFormat("HH:mm:ss", java.util.Locale.getDefault()).format(java.util.Date())
messageCount.value = messageCount.value + 1
@@ -916,7 +853,6 @@ class CreateMemoryViewModel(
}
}
MessageType.agent_response -> {
// 处理Agent回复可能有多条消息每条作为单独气泡显示
val text = message.getString("text") ?: ""
val isTransition = message.getBoolean("transition") == true
val index = message.getInt("index") ?: 0
@@ -927,12 +863,10 @@ class CreateMemoryViewModel(
return
}
// 收到第一条回复时,隐藏打字指示器
if (index == 0) {
isTyping.value = false
}
// 每条消息立即作为单独的气泡添加到历史消息
conversationId.value?.let { id ->
val aiMessage = MessageDto(
id = "ai_${System.currentTimeMillis()}_$index",
@@ -945,42 +879,35 @@ class CreateMemoryViewModel(
historyMessages.value = historyMessages.value + aiMessage
}
// 更新 agentResponse用于显示最新回复
if (index == 0) {
agentResponse.value = text
} else {
agentResponse.value += "\n\n$text"
}
// 如果是最后一条消息,结束流式状态
if (index >= total - 1) {
isStreaming.value = false
streamingText.value = ""
isTyping.value = false
webSocketClient.setGenerating(false)
conversationRealtime.setGenerating(false)
}
}
MessageType.agent_response_start -> {
// 流式回复开始
isStreaming.value = true
isTyping.value = false // 流式开始时隐藏打字指示器
isTyping.value = false
streamingText.value = ""
webSocketClient.setGenerating(true)
conversationRealtime.setGenerating(true)
}
MessageType.agent_response_chunk -> {
// 流式回复片段
val chunk = message.getString("text") ?: ""
streamingText.value += chunk
// 收到第一个片段时,隐藏打字指示器
if (streamingText.value == chunk) {
isTyping.value = false
}
}
MessageType.agent_response_end -> {
// 流式回复结束
agentResponse.value = streamingText.value
// 添加到历史消息
conversationId.value?.let { id ->
val aiMessage = MessageDto(
id = "ai_${System.currentTimeMillis()}",
@@ -995,15 +922,12 @@ class CreateMemoryViewModel(
isStreaming.value = false
streamingText.value = ""
webSocketClient.setGenerating(false)
conversationRealtime.setGenerating(false)
}
MessageType.agent_typing -> {
// AI正在输入
isTyping.value = true
}
MessageType.text -> {
// 处理文本消息响应(如果需要)
}
MessageType.text -> { }
MessageType.connect -> {
Log.d(TAG, "收到连接确认消息,设置状态为已连接")
connectionStatus.value = "已连接"
@@ -1016,8 +940,7 @@ class CreateMemoryViewModel(
connectionStatus.value = "对话已结束"
isRecording.value = false
isStreaming.value = false
webSocketClient.setGenerating(false)
// 触发对话结束后的处理
conversationRealtime.setGenerating(false)
handleConversationEnded()
}
MessageType.error -> {
@@ -1027,19 +950,16 @@ class CreateMemoryViewModel(
wsIsConnected.value = false
isStreaming.value = false
isTyping.value = false
webSocketClient.setGenerating(false)
conversationRealtime.setGenerating(false)
}
else -> {}
}
}
/**
* 取消当前正在生成的回复
*/
fun cancelGeneration() {
viewModelScope.launch {
conversationId.value?.let { id ->
webSocketClient.cancelGeneration(id)
conversationRealtime.cancelGeneration(id)
isStreaming.value = false
streamingText.value = ""
isTyping.value = false
@@ -1047,26 +967,20 @@ class CreateMemoryViewModel(
}
}
/**
* 处理对话结束后的逻辑
* 等待后台任务完成,然后刷新章节列表
*/
private fun handleConversationEnded() {
viewModelScope.launch {
isProcessing.value = true
processingStatus.value = "正在处理回忆录..."
// 等待后台处理完成最多等待3分钟
val maxWaitSeconds = 180
val checkInterval = 5 // 每5秒检查一次
val checkInterval = 5
var elapsed = 0
while (elapsed < maxWaitSeconds) {
delay(checkInterval * 1000L)
elapsed += checkInterval
// 检查任务状态
val tasksStatus = apiService.getTasksStatus()
val tasksStatus = conversationApi.getTasksStatus()
tasksStatus.fold(
onSuccess = { status ->
val total = status.total
@@ -1078,7 +992,6 @@ class CreateMemoryViewModel(
processingStatus.value = "处理中: 总任务$total, 等待$pending, 运行$running, 成功$success, 失败$failure"
// 如果所有任务完成,刷新章节列表
if (total > 0 && allCompleted) {
processingStatus.value = "处理完成!"
refreshChapters()
@@ -1086,10 +999,8 @@ class CreateMemoryViewModel(
return@launch
}
// 如果没有任务但有章节内容,也认为完成
if (total == 0 && elapsed > 30) {
// 检查是否有章节
val chaptersResult = apiService.getChapters()
val chaptersResult = conversationApi.getChapters()
chaptersResult.fold(
onSuccess = { chapters ->
if (chapters.isNotEmpty()) {
@@ -1109,21 +1020,16 @@ class CreateMemoryViewModel(
)
}
// 超时后也刷新章节列表
processingStatus.value = "处理超时,正在刷新章节..."
refreshChapters()
isProcessing.value = false
}
}
/**
* 刷新章节列表
*/
private suspend fun refreshChapters() {
val chaptersResult = apiService.getChapters()
val chaptersResult = conversationApi.getChapters()
chaptersResult.fold(
onSuccess = { chapterDtos ->
// 转换为本地Chapter实体并保存
val chapters = chapterDtos.map { dto ->
Chapter(
id = dto.id,
@@ -1154,7 +1060,7 @@ class CreateMemoryViewModel(
super.onCleared()
pendingSegmentRetryJob?.cancel()
viewModelScope.launch {
webSocketClient.disconnect()
conversationRealtime.disconnect()
}
recordingCoordinator.release()
audioPlayer.release()

View File

@@ -3,83 +3,39 @@ package com.huaga.life_echo.ui.viewmodel
import android.content.Context
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import com.huaga.life_echo.config.AppConfig
import com.huaga.life_echo.data.auth.TokenManager
import com.huaga.life_echo.data.database.AppDatabase
import com.huaga.life_echo.data.repository.*
import com.huaga.life_echo.network.ApiService
import com.huaga.life_echo.network.AuthService
import com.huaga.life_echo.payment.PaymentManager
import com.huaga.life_echo.app.AppContainer
import com.huaga.life_echo.app.LifeEchoApp
import com.huaga.life_echo.feature.conversation.adapters.ConversationApiAdapter
import com.huaga.life_echo.feature.conversation.adapters.ConversationRealtimeAdapter
import com.huaga.life_echo.feature.memoir.adapters.MemoirApiAdapter
class ViewModelFactory(private val context: Context) : ViewModelProvider.Factory {
init {
// 初始化TokenManager
TokenManager.initialize(context)
}
private val database by lazy { AppDatabase.getDatabase(context) }
private val authService by lazy { AuthService() }
private val apiService by lazy {
ApiService(
tokenManager = TokenManager,
authService = authService
)
}
private val conversationRepository by lazy {
ConversationRepository(
conversationDao = database.conversationDao(),
segmentDao = database.conversationSegmentDao(),
apiService = apiService
)
}
private val chapterRepository by lazy {
ChapterRepository(
chapterDao = database.chapterDao()
)
}
private val messageRepository by lazy {
MessageRepository(
messageDao = database.messageDao(),
apiService = apiService
)
}
private val paymentRepository by lazy {
PaymentRepository(apiService = apiService)
}
private val profileRepository by lazy {
ProfileRepository(apiService = apiService)
}
private val paymentManager by lazy {
PaymentManager(
context = context.applicationContext,
wechatAppId = AppConfig.WECHAT_APP_ID
)
}
private val container: AppContainer
get() = (context.applicationContext as LifeEchoApp).container
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
return when {
modelClass.isAssignableFrom(CreateMemoryViewModel::class.java) -> {
CreateMemoryViewModel(
conversationRepository = conversationRepository,
chapterRepository = chapterRepository,
messageRepository = messageRepository,
context = context
conversationRepository = container.conversationRepository,
chapterRepository = container.chapterRepository,
messageRepository = container.messageRepository,
context = context,
conversationApi = ConversationApiAdapter(container.apiService),
conversationRealtime = ConversationRealtimeAdapter(container.webSocketClient),
) as T
}
modelClass.isAssignableFrom(ConversationListViewModel::class.java) -> {
ConversationListViewModel(
conversationRepository = conversationRepository
conversationRepository = container.conversationRepository
) as T
}
modelClass.isAssignableFrom(MyMemoirViewModel::class.java) -> {
MyMemoirViewModel(
chapterRepository = chapterRepository,
apiService = apiService
chapterRepository = container.chapterRepository,
memoirApi = container.memoirApi
) as T
}
modelClass.isAssignableFrom(AuthViewModel::class.java) -> {
@@ -87,17 +43,16 @@ class ViewModelFactory(private val context: Context) : ViewModelProvider.Factory
}
modelClass.isAssignableFrom(PaymentViewModel::class.java) -> {
PaymentViewModel(
paymentRepository = paymentRepository,
paymentManager = paymentManager
paymentRepository = container.paymentRepository,
paymentManager = container.paymentManager
) as T
}
modelClass.isAssignableFrom(ProfileViewModel::class.java) -> {
ProfileViewModel(
profileRepository = profileRepository
profileRepository = container.profileRepository
) as T
}
else -> throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}

View File

@@ -12,14 +12,23 @@ import com.huaga.life_echo.data.database.MessageDao
import com.huaga.life_echo.data.repository.ChapterRepository
import com.huaga.life_echo.data.repository.ConversationRepository
import com.huaga.life_echo.data.repository.MessageRepository
import com.huaga.life_echo.feature.conversation.ports.AudioSegmentRequest
import com.huaga.life_echo.feature.conversation.ports.ConversationApiPort
import com.huaga.life_echo.feature.conversation.ports.ConversationRealtimePort
import com.huaga.life_echo.feature.voice.RecorderEngine
import com.huaga.life_echo.feature.voice.RecorderStartResult
import com.huaga.life_echo.feature.voice.RecorderStopResult
import com.huaga.life_echo.feature.voice.RecordingCoordinator
import com.huaga.life_echo.network.ApiService
import com.huaga.life_echo.network.WebSocketMessage
import com.huaga.life_echo.network.models.ChapterDto
import com.huaga.life_echo.network.models.CreateConversationResponse
import com.huaga.life_echo.network.models.MessageDto
import com.huaga.life_echo.network.models.TasksStatusDto
import com.huaga.life_echo.testutil.MainDispatcherRule
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
@@ -141,6 +150,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
}
}
@Suppress("DEPRECATION")
private fun newViewModel(
context: Context,
recordingCoordinator: RecordingCoordinator,
@@ -160,6 +170,8 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
apiService = apiService,
),
context = context,
conversationApi = NoOpConversationApiPort(),
conversationRealtime = NoOpConversationRealtimePort(),
recordingCoordinator = recordingCoordinator,
tokenInitializer = {},
)
@@ -234,63 +246,77 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
private class FakeConversationDao : ConversationDao {
override fun getAllConversations() = flowOf(emptyList<Conversation>())
override suspend fun getConversationById(id: String): Conversation? = null
override suspend fun getLatestEmptyConversation(): Conversation? = null
override suspend fun insertConversation(conversation: Conversation) = Unit
override suspend fun updateConversation(conversation: Conversation) = Unit
override suspend fun deleteConversation(conversation: Conversation) = Unit
override suspend fun deleteOtherEmptyConversations(keepId: String) = Unit
}
private class FakeConversationSegmentDao : ConversationSegmentDao {
override fun getSegmentsByConversationId(conversationId: String) =
flowOf(emptyList<ConversationSegment>())
override suspend fun insertSegment(segment: ConversationSegment) = Unit
override suspend fun insertSegments(segments: List<ConversationSegment>) = Unit
override suspend fun updateSegment(segment: ConversationSegment) = Unit
override suspend fun deleteSegment(segment: ConversationSegment) = Unit
}
private class FakeChapterDao : ChapterDao {
override fun getAllChapters() = flowOf(emptyList<Chapter>())
override suspend fun getChapterById(id: String): Chapter? = null
override suspend fun getChaptersByCategory(category: String): List<Chapter> = emptyList()
override suspend fun insertChapter(chapter: Chapter) = Unit
override suspend fun insertChapters(chapters: List<Chapter>) = Unit
override suspend fun updateChapter(chapter: Chapter) = Unit
override suspend fun deleteChapter(chapter: Chapter) = Unit
}
private class FakeMessageDao : MessageDao {
override fun getMessagesByConversationId(conversationId: String) =
flowOf(emptyList<Message>())
override suspend fun getMessageById(id: String): Message? = null
override suspend fun insertMessage(message: Message) = Unit
override suspend fun insertMessages(messages: List<Message>) = Unit
override suspend fun updateMessage(message: Message) = Unit
override suspend fun deleteMessage(message: Message) = Unit
override suspend fun deleteMessagesByConversationId(conversationId: String) = Unit
}
private class NoOpConversationApiPort : ConversationApiPort {
override suspend fun createConversation(): Result<CreateConversationResponse> =
Result.failure(Exception("no-op"))
override suspend fun getMessages(conversationId: String): Result<List<MessageDto>> =
Result.success(emptyList())
override suspend fun getTasksStatus(): Result<TasksStatusDto> =
Result.failure(Exception("no-op"))
override suspend fun getChapters(): Result<List<ChapterDto>> =
Result.success(emptyList())
override suspend fun clearTasks(): Result<Unit> =
Result.success(Unit)
}
private class NoOpConversationRealtimePort : ConversationRealtimePort {
override val state: StateFlow<ConversationRealtimePort.State> =
MutableStateFlow(ConversationRealtimePort.State.NotConnected)
override suspend fun prepare() = Unit
override suspend fun connect(
conversationId: String,
token: String?,
onMessage: (WebSocketMessage) -> Unit,
onError: ((String) -> Unit)?,
) = Unit
override suspend fun disconnect() = Unit
override fun isConnected(): Boolean = false
override suspend fun sendText(conversationId: String, text: String) = Unit
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit
override suspend fun sendTranscribeOnly(audioBytes: ByteArray, conversationId: String) = Unit
override suspend fun sendEndConversation(conversationId: String) = Unit
override suspend fun cancelGeneration(conversationId: String) = Unit
override fun isGenerating(): Boolean = false
override fun setGenerating(generating: Boolean) = Unit
}
}

View File

@@ -0,0 +1,201 @@
package com.huaga.life_echo.ui.viewmodel
import android.content.Context
import com.huaga.life_echo.data.database.Chapter
import com.huaga.life_echo.data.database.ChapterDao
import com.huaga.life_echo.data.database.Conversation
import com.huaga.life_echo.data.database.ConversationDao
import com.huaga.life_echo.data.database.ConversationSegment
import com.huaga.life_echo.data.database.ConversationSegmentDao
import com.huaga.life_echo.data.database.Message
import com.huaga.life_echo.data.database.MessageDao
import com.huaga.life_echo.data.repository.ChapterRepository
import com.huaga.life_echo.data.repository.ConversationRepository
import com.huaga.life_echo.data.repository.MessageRepository
import com.huaga.life_echo.feature.conversation.ports.AudioSegmentRequest
import com.huaga.life_echo.feature.conversation.ports.ConversationApiPort
import com.huaga.life_echo.feature.conversation.ports.ConversationRealtimePort
import com.huaga.life_echo.network.ApiService
import com.huaga.life_echo.network.WebSocketMessage
import com.huaga.life_echo.network.models.ChapterDto
import com.huaga.life_echo.network.models.CreateConversationResponse
import com.huaga.life_echo.network.models.MessageDto
import com.huaga.life_echo.network.models.TasksStatusDto
import com.huaga.life_echo.testutil.MainDispatcherRule
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Rule
import org.junit.Test
import org.mockito.Mockito
import java.io.File
import java.nio.file.Files
@OptIn(ExperimentalCoroutinesApi::class)
class CreateMemoryViewModelWarmupTest {
@get:Rule
val mainDispatcherRule = MainDispatcherRule()
@Test
fun initialize_conversation_prepares_realtime_before_connect() =
runTest(mainDispatcherRule.dispatcher.scheduler) {
val rootDir = Files.createTempDirectory("warmup-test").toFile()
val context = newContext(rootDir)
val realtime = FakeConversationRealtimePort()
val viewModel = newViewModel(context, realtime = realtime)
try {
viewModel.initializeConversation("conversation-1")
advanceUntilIdle()
assertEquals(listOf("prepare", "connect:conversation-1"), realtime.calls)
} finally {
rootDir.deleteRecursively()
}
}
@Test
fun initialize_conversation_with_null_does_not_prepare() =
runTest(mainDispatcherRule.dispatcher.scheduler) {
val rootDir = Files.createTempDirectory("warmup-test").toFile()
val context = newContext(rootDir)
val realtime = FakeConversationRealtimePort()
val viewModel = newViewModel(context, realtime = realtime)
try {
viewModel.initializeConversation(null)
advanceUntilIdle()
assertEquals(emptyList<String>(), realtime.calls)
} finally {
rootDir.deleteRecursively()
}
}
@Suppress("DEPRECATION")
private fun newViewModel(
context: Context,
realtime: ConversationRealtimePort = FakeConversationRealtimePort(),
): CreateMemoryViewModel {
val apiService = ApiService()
return CreateMemoryViewModel(
conversationRepository = ConversationRepository(
conversationDao = FakeConversationDao(),
segmentDao = FakeConversationSegmentDao(),
apiService = apiService,
),
chapterRepository = ChapterRepository(
chapterDao = FakeChapterDao(),
),
messageRepository = MessageRepository(
messageDao = FakeMessageDao(),
apiService = apiService,
),
context = context,
conversationApi = FakeConversationApiPort(),
conversationRealtime = realtime,
tokenInitializer = {},
)
}
private fun newContext(rootDir: File): Context {
val filesDir = File(rootDir, "files").apply { mkdirs() }
val cacheDir = File(rootDir, "cache").apply { mkdirs() }
val context = Mockito.mock(Context::class.java)
Mockito.`when`(context.filesDir).thenReturn(filesDir)
Mockito.`when`(context.cacheDir).thenReturn(cacheDir)
Mockito.`when`(context.applicationContext).thenReturn(context)
return context
}
private class FakeConversationRealtimePort : ConversationRealtimePort {
val calls = mutableListOf<String>()
override val state: StateFlow<ConversationRealtimePort.State> =
MutableStateFlow(ConversationRealtimePort.State.NotConnected)
override suspend fun prepare() {
calls += "prepare"
}
override suspend fun connect(
conversationId: String,
token: String?,
onMessage: (WebSocketMessage) -> Unit,
onError: ((String) -> Unit)?,
) {
calls += "connect:$conversationId"
}
override suspend fun disconnect() {
calls += "disconnect"
}
override fun isConnected(): Boolean = false
override suspend fun sendText(conversationId: String, text: String) = Unit
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit
override suspend fun sendTranscribeOnly(audioBytes: ByteArray, conversationId: String) = Unit
override suspend fun sendEndConversation(conversationId: String) = Unit
override suspend fun cancelGeneration(conversationId: String) = Unit
override fun isGenerating(): Boolean = false
override fun setGenerating(generating: Boolean) = Unit
}
private class FakeConversationApiPort : ConversationApiPort {
override suspend fun createConversation(): Result<CreateConversationResponse> =
Result.failure(Exception("no-op"))
override suspend fun getMessages(conversationId: String): Result<List<MessageDto>> =
Result.failure(Exception("test-offline"))
override suspend fun getTasksStatus(): Result<TasksStatusDto> =
Result.failure(Exception("no-op"))
override suspend fun getChapters(): Result<List<ChapterDto>> =
Result.success(emptyList())
override suspend fun clearTasks(): Result<Unit> = Result.success(Unit)
}
private class FakeConversationDao : ConversationDao {
override fun getAllConversations() = flowOf(emptyList<Conversation>())
override suspend fun getConversationById(id: String): Conversation? = null
override suspend fun getLatestEmptyConversation(): Conversation? = null
override suspend fun insertConversation(conversation: Conversation) = Unit
override suspend fun updateConversation(conversation: Conversation) = Unit
override suspend fun deleteConversation(conversation: Conversation) = Unit
override suspend fun deleteOtherEmptyConversations(keepId: String) = Unit
}
private class FakeConversationSegmentDao : ConversationSegmentDao {
override fun getSegmentsByConversationId(conversationId: String) =
flowOf(emptyList<ConversationSegment>())
override suspend fun insertSegment(segment: ConversationSegment) = Unit
override suspend fun insertSegments(segments: List<ConversationSegment>) = Unit
override suspend fun updateSegment(segment: ConversationSegment) = Unit
override suspend fun deleteSegment(segment: ConversationSegment) = Unit
}
private class FakeChapterDao : ChapterDao {
override fun getAllChapters() = flowOf(emptyList<Chapter>())
override suspend fun getChapterById(id: String): Chapter? = null
override suspend fun getChaptersByCategory(category: String): List<Chapter> = emptyList()
override suspend fun insertChapter(chapter: Chapter) = Unit
override suspend fun insertChapters(chapters: List<Chapter>) = Unit
override suspend fun updateChapter(chapter: Chapter) = Unit
override suspend fun deleteChapter(chapter: Chapter) = Unit
}
private class FakeMessageDao : MessageDao {
override fun getMessagesByConversationId(conversationId: String) =
flowOf(emptyList<Message>())
override suspend fun getMessageById(id: String): Message? = null
override suspend fun insertMessage(message: Message) = Unit
override suspend fun insertMessages(messages: List<Message>) = Unit
override suspend fun updateMessage(message: Message) = Unit
override suspend fun deleteMessage(message: Message) = Unit
override suspend fun deleteMessagesByConversationId(conversationId: String) = Unit
}
}