重构: 将 CreateMemoryViewModel 迁移到对话端口
ViewModel 现在依赖 ConversationApiPort 和 ConversationRealtimePort,而不是 ApiService/WebSocketClient。 ViewModelFactory 从 AppContainer 装配这些适配器。
This commit is contained in:
@@ -8,6 +8,9 @@ import com.huaga.life_echo.data.auth.TokenManager
|
||||
import com.huaga.life_echo.data.repository.ConversationRepository
|
||||
import com.huaga.life_echo.data.repository.ChapterRepository
|
||||
import com.huaga.life_echo.data.repository.MessageRepository
|
||||
import com.huaga.life_echo.feature.conversation.ports.AudioSegmentRequest
|
||||
import com.huaga.life_echo.feature.conversation.ports.ConversationApiPort
|
||||
import com.huaga.life_echo.feature.conversation.ports.ConversationRealtimePort
|
||||
import com.huaga.life_echo.feature.voice.AudioPlayer
|
||||
import com.huaga.life_echo.feature.voice.AudioSegmentException
|
||||
import com.huaga.life_echo.feature.voice.AudioSegmentFile
|
||||
@@ -23,11 +26,8 @@ import com.huaga.life_echo.feature.voice.PlaybackInfo
|
||||
import com.huaga.life_echo.feature.voice.ConversationScopedTranscriptRouter
|
||||
import com.huaga.life_echo.feature.voice.SegmentedVoiceTranscriptState
|
||||
import com.huaga.life_echo.feature.voice.VoiceRecorder
|
||||
import com.huaga.life_echo.network.WebSocketClient
|
||||
import com.huaga.life_echo.network.WebSocketMessage
|
||||
import com.huaga.life_echo.network.MessageType
|
||||
import com.huaga.life_echo.network.ApiService
|
||||
import com.huaga.life_echo.network.AuthService
|
||||
import com.huaga.life_echo.network.models.MessageDto
|
||||
import com.huaga.life_echo.data.database.Chapter
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
@@ -52,6 +52,8 @@ class CreateMemoryViewModel(
|
||||
private val chapterRepository: ChapterRepository,
|
||||
private val messageRepository: MessageRepository,
|
||||
private val context: Context,
|
||||
private val conversationApi: ConversationApiPort,
|
||||
private val conversationRealtime: ConversationRealtimePort,
|
||||
private val recordingCoordinator: RecordingCoordinator = createDefaultRecordingCoordinator(context),
|
||||
private val audioPlayer: AudioPlayer = AudioPlayer(context),
|
||||
private val tokenInitializer: (Context) -> Unit = TokenManager::initialize,
|
||||
@@ -80,8 +82,6 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
|
||||
private val webSocketClient = WebSocketClient()
|
||||
private val apiService = ApiService(TokenManager, AuthService())
|
||||
private val pendingVoiceSegmentStore = PendingVoiceSegmentStore(
|
||||
File(context.filesDir, "pending-voice-segments")
|
||||
)
|
||||
@@ -96,45 +96,35 @@ class CreateMemoryViewModel(
|
||||
val agentResponse = MutableStateFlow("")
|
||||
val connectionStatus = MutableStateFlow("未连接")
|
||||
val conversationId = MutableStateFlow<String?>(null)
|
||||
val userMessages = MutableStateFlow<List<String>>(emptyList()) // 用户发送的文本消息列表
|
||||
val userMessages = MutableStateFlow<List<String>>(emptyList())
|
||||
|
||||
// 历史消息
|
||||
val historyMessages = MutableStateFlow<List<MessageDto>>(emptyList())
|
||||
|
||||
// 流式消息相关
|
||||
val isStreaming = MutableStateFlow(false) // 是否正在流式接收
|
||||
val streamingText = MutableStateFlow("") // 流式文本内容
|
||||
val isTyping = MutableStateFlow(false) // AI是否正在输入
|
||||
val isStreaming = MutableStateFlow(false)
|
||||
val streamingText = MutableStateFlow("")
|
||||
val isTyping = MutableStateFlow(false)
|
||||
|
||||
// WebSocket连接状态(用于调试)
|
||||
val wsIsConnected = MutableStateFlow(false)
|
||||
|
||||
// 后台处理状态
|
||||
val isProcessing = MutableStateFlow(false) // 是否正在处理回忆录
|
||||
val processingStatus = MutableStateFlow("") // 处理状态文本
|
||||
val isProcessing = MutableStateFlow(false)
|
||||
val processingStatus = MutableStateFlow("")
|
||||
|
||||
// 调试信息
|
||||
val lastMessageType = MutableStateFlow<String?>(null) // 最后收到的消息类型
|
||||
val lastMessageTime = MutableStateFlow<String?>(null) // 最后消息时间
|
||||
val errorMessages = MutableStateFlow<List<String>>(emptyList()) // 错误消息列表
|
||||
val messageCount = MutableStateFlow(0) // 消息计数
|
||||
val lastMessageType = MutableStateFlow<String?>(null)
|
||||
val lastMessageTime = MutableStateFlow<String?>(null)
|
||||
val errorMessages = MutableStateFlow<List<String>>(emptyList())
|
||||
val messageCount = MutableStateFlow(0)
|
||||
|
||||
// 语音录制相关状态
|
||||
val isVoiceRecording: StateFlow<Boolean> = recordingCoordinator.isRecording
|
||||
val recordingDuration: StateFlow<Int> = recordingCoordinator.recordingDuration
|
||||
|
||||
// 音频播放相关状态
|
||||
val playbackInfo: StateFlow<PlaybackInfo> = audioPlayer.playbackInfo
|
||||
|
||||
// 音频文件路径映射 (messageId -> filePath)
|
||||
private val _audioFilePaths = MutableStateFlow<Map<String, String>>(emptyMap())
|
||||
val audioFilePaths: StateFlow<Map<String, String>> = _audioFilePaths.asStateFlow()
|
||||
|
||||
// 音频时长映射 (messageId -> duration in seconds)
|
||||
private val _audioDurations = MutableStateFlow<Map<String, Int>>(emptyMap())
|
||||
val audioDurations: StateFlow<Map<String, Int>> = _audioDurations.asStateFlow()
|
||||
|
||||
// 仅转写结果:发 transcribe_only 后等待 transcript 用
|
||||
private val pendingTranscribeChannel = Channel<String>(Channel.RENDEZVOUS)
|
||||
@Volatile
|
||||
private var waitingForTranscribeOnly = false
|
||||
@@ -146,9 +136,6 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化对话(如果提供了conversationId)
|
||||
*/
|
||||
fun initializeConversation(convId: String?) {
|
||||
if (convId == null || convId == "new") {
|
||||
return
|
||||
@@ -160,13 +147,13 @@ class CreateMemoryViewModel(
|
||||
segmentedVoiceTranscriptState = SegmentedVoiceTranscriptState()
|
||||
|
||||
try {
|
||||
// 加载历史消息
|
||||
conversationRealtime.prepare()
|
||||
|
||||
loadHistoryMessages(convId)
|
||||
|
||||
// 获取访问令牌并连接WebSocket
|
||||
val token = TokenManager.getAccessToken()
|
||||
Log.d(TAG, "初始化对话,conversationId: $convId")
|
||||
webSocketClient.connect(
|
||||
conversationRealtime.connect(
|
||||
convId,
|
||||
token,
|
||||
onMessage = { message ->
|
||||
@@ -176,13 +163,9 @@ class CreateMemoryViewModel(
|
||||
onError = { errorMsg ->
|
||||
Log.e(TAG, "WebSocket错误: $errorMsg")
|
||||
connectionStatus.value = "错误: $errorMsg"
|
||||
// 添加到错误列表(最多保留10条)
|
||||
errorMessages.value = (errorMessages.value + errorMsg).takeLast(10)
|
||||
}
|
||||
)
|
||||
|
||||
// 注意:连接状态将在收到服务器的connect消息时更新
|
||||
// 这里不立即设置为"已连接",等待服务器确认
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "初始化对话失败: ${e.message}", e)
|
||||
connectionStatus.value = "连接失败: ${e.message}"
|
||||
@@ -190,19 +173,14 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 加载历史消息
|
||||
*/
|
||||
private suspend fun loadHistoryMessages(convId: String) {
|
||||
val result = apiService.getMessages(convId)
|
||||
val result = conversationApi.getMessages(convId)
|
||||
result.fold(
|
||||
onSuccess = { messages ->
|
||||
historyMessages.value = messages
|
||||
// 同步到本地数据库
|
||||
messageRepository.syncMessages(convId)
|
||||
},
|
||||
onFailure = { exception ->
|
||||
// 如果加载失败,记录错误但不阻止连接
|
||||
connectionStatus.value = "加载历史消息失败: ${exception.message}"
|
||||
}
|
||||
)
|
||||
@@ -213,8 +191,7 @@ class CreateMemoryViewModel(
|
||||
connectionStatus.value = "创建对话中..."
|
||||
|
||||
try {
|
||||
// 先通过API创建对话
|
||||
val createResult = apiService.createConversation()
|
||||
val createResult = conversationApi.createConversation()
|
||||
val convId = createResult.fold(
|
||||
onSuccess = { response -> response.id },
|
||||
onFailure = { exception ->
|
||||
@@ -228,7 +205,6 @@ class CreateMemoryViewModel(
|
||||
isRecording.value = true
|
||||
connectionStatus.value = "连接中..."
|
||||
|
||||
// 新建对话:清空当前会话展示状态
|
||||
historyMessages.value = emptyList()
|
||||
userMessages.value = emptyList()
|
||||
segmentedVoiceTranscriptState = SegmentedVoiceTranscriptState()
|
||||
@@ -239,13 +215,11 @@ class CreateMemoryViewModel(
|
||||
_audioFilePaths.value = emptyMap()
|
||||
_audioDurations.value = emptyMap()
|
||||
|
||||
// 清除旧的任务记录
|
||||
apiService.clearTasks()
|
||||
conversationApi.clearTasks()
|
||||
|
||||
// 获取访问令牌并连接WebSocket
|
||||
val token = TokenManager.getAccessToken()
|
||||
Log.d(TAG, "开始对话,conversationId: $convId")
|
||||
webSocketClient.connect(
|
||||
conversationRealtime.connect(
|
||||
convId,
|
||||
token,
|
||||
onMessage = { message ->
|
||||
@@ -256,12 +230,9 @@ class CreateMemoryViewModel(
|
||||
Log.e(TAG, "WebSocket错误: $errorMsg")
|
||||
connectionStatus.value = "错误: $errorMsg"
|
||||
isRecording.value = false
|
||||
// 添加到错误列表(最多保留10条)
|
||||
errorMessages.value = (errorMessages.value + errorMsg).takeLast(10)
|
||||
}
|
||||
)
|
||||
|
||||
// 注意:连接状态将在收到服务器的connect消息时更新
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "开始对话失败: ${e.message}", e)
|
||||
connectionStatus.value = "连接失败: ${e.message}"
|
||||
@@ -273,8 +244,8 @@ class CreateMemoryViewModel(
|
||||
fun endConversation() {
|
||||
viewModelScope.launch {
|
||||
conversationId.value?.let { id ->
|
||||
webSocketClient.sendEndConversation(id)
|
||||
webSocketClient.disconnect()
|
||||
conversationRealtime.sendEndConversation(id)
|
||||
conversationRealtime.disconnect()
|
||||
connectionStatus.value = "已断开"
|
||||
wsIsConnected.value = false
|
||||
isRecording.value = false
|
||||
@@ -285,7 +256,7 @@ class CreateMemoryViewModel(
|
||||
fun sendAudioChunk(chunk: ByteArray) {
|
||||
viewModelScope.launch {
|
||||
conversationId.value?.let { id ->
|
||||
webSocketClient.sendAudioChunk(chunk, id)
|
||||
conversationRealtime.sendAudioChunk(chunk, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -344,13 +315,13 @@ class CreateMemoryViewModel(
|
||||
return@withLock null
|
||||
}
|
||||
|
||||
if (!webSocketClient.isConnected()) {
|
||||
if (!conversationRealtime.isConnected()) {
|
||||
Log.w(TAG, "分段发送前 WebSocket 未连接,开始等待重连")
|
||||
connectionStatus.value = "未连接,正在重连..."
|
||||
wsIsConnected.value = false
|
||||
val token = TokenManager.getAccessToken()
|
||||
try {
|
||||
webSocketClient.connect(
|
||||
conversationRealtime.connect(
|
||||
readyConversationId,
|
||||
token,
|
||||
onMessage = { message -> handleWebSocketMessage(message) },
|
||||
@@ -363,7 +334,7 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
|
||||
val connected = withTimeoutOrNull(SEGMENT_WAIT_TIMEOUT_MS) {
|
||||
if (webSocketClient.isConnected()) {
|
||||
if (conversationRealtime.isConnected()) {
|
||||
true
|
||||
} else {
|
||||
wsIsConnected.first { it }
|
||||
@@ -378,9 +349,6 @@ class CreateMemoryViewModel(
|
||||
readyConversationId
|
||||
}
|
||||
|
||||
/**
|
||||
* 开始录音(需 API 26+ 与录音权限)
|
||||
*/
|
||||
fun startRecordingVoice() {
|
||||
Log.d(TAG, "开始录音")
|
||||
when (val result = recordingCoordinator.start()) {
|
||||
@@ -570,12 +538,12 @@ class CreateMemoryViewModel(
|
||||
startConversation()
|
||||
delay(300)
|
||||
}
|
||||
if (!webSocketClient.isConnected()) {
|
||||
if (!conversationRealtime.isConnected()) {
|
||||
connectionStatus.value = "未连接,正在重连..."
|
||||
conversationId.value?.let { id ->
|
||||
val token = TokenManager.getAccessToken()
|
||||
try {
|
||||
webSocketClient.connect(
|
||||
conversationRealtime.connect(
|
||||
id,
|
||||
token,
|
||||
onMessage = { m -> handleWebSocketMessage(m) },
|
||||
@@ -594,7 +562,7 @@ class CreateMemoryViewModel(
|
||||
val id = conversationId.value ?: return@launch
|
||||
try {
|
||||
waitingForTranscribeOnly = true
|
||||
webSocketClient.sendTranscribeOnly(audioBytes, id)
|
||||
conversationRealtime.sendTranscribeOnly(audioBytes, id)
|
||||
val text = withTimeoutOrNull(15000L) { pendingTranscribeChannel.receive() }
|
||||
waitingForTranscribeOnly = false
|
||||
if (!text.isNullOrBlank() && !text.startsWith("转写失败")) {
|
||||
@@ -612,9 +580,6 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送分段语音(持久化待发队列 + 自动重试,不阻塞后续分段发送)
|
||||
*/
|
||||
private suspend fun dispatchPendingVoiceSegment(pendingSegment: PendingVoiceSegment) {
|
||||
if (!tryAcquirePendingDispatch(pendingSegment.clientSegmentId)) {
|
||||
return
|
||||
@@ -646,14 +611,16 @@ class CreateMemoryViewModel(
|
||||
var lastError: Exception? = null
|
||||
for (attempt in 1..SEGMENT_RETRY_MAX_ATTEMPTS) {
|
||||
try {
|
||||
webSocketClient.sendAudioSegment(
|
||||
audioBytes = segmentToSend.audioBytes,
|
||||
conversationId = segmentToSend.conversationId,
|
||||
voiceSessionId = segmentToSend.voiceSessionId,
|
||||
segmentIndex = segmentToSend.segmentIndex,
|
||||
duration = segmentToSend.durationSeconds,
|
||||
isLast = segmentToSend.isLast,
|
||||
clientSegmentId = segmentToSend.clientSegmentId,
|
||||
conversationRealtime.sendAudioSegment(
|
||||
AudioSegmentRequest(
|
||||
audioBytes = segmentToSend.audioBytes,
|
||||
conversationId = segmentToSend.conversationId,
|
||||
voiceSessionId = segmentToSend.voiceSessionId,
|
||||
segmentIndex = segmentToSend.segmentIndex,
|
||||
duration = segmentToSend.durationSeconds,
|
||||
isLast = segmentToSend.isLast,
|
||||
clientSegmentId = segmentToSend.clientSegmentId,
|
||||
)
|
||||
)
|
||||
pendingVoiceSegmentStore.remove(segmentToSend.clientSegmentId)
|
||||
Log.d(TAG, "分段语音发送成功: idx=${segmentToSend.segmentIndex}, attempt=$attempt")
|
||||
@@ -683,28 +650,22 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送音频消息
|
||||
*/
|
||||
private suspend fun sendAudioMessage(audioBytes: ByteArray, filePath: String, durationSeconds: Int) {
|
||||
Log.d(TAG, "准备发送音频消息,大小: ${audioBytes.size}, 时长: ${durationSeconds}s")
|
||||
|
||||
// 确保已连接(手势松开后马上发语音,尽量缩短等待)
|
||||
if (conversationId.value == null) {
|
||||
Log.d(TAG, "对话ID为空,开始创建新对话")
|
||||
startConversation()
|
||||
delay(300)
|
||||
}
|
||||
|
||||
// 检查连接状态
|
||||
if (!webSocketClient.isConnected()) {
|
||||
if (!conversationRealtime.isConnected()) {
|
||||
Log.w(TAG, "WebSocket未连接,尝试重连")
|
||||
connectionStatus.value = "未连接,正在重连..."
|
||||
// 重连逻辑类似 sendTextMessage
|
||||
conversationId.value?.let { id ->
|
||||
val token = TokenManager.getAccessToken()
|
||||
try {
|
||||
webSocketClient.connect(
|
||||
conversationRealtime.connect(
|
||||
id,
|
||||
token,
|
||||
onMessage = { message -> handleWebSocketMessage(message) },
|
||||
@@ -722,10 +683,8 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
|
||||
conversationId.value?.let { id ->
|
||||
// 生成临时消息 ID
|
||||
val tempMessageId = "audio_user_${System.currentTimeMillis()}"
|
||||
|
||||
// 添加到历史消息(本地先显示)
|
||||
val tempMessage = MessageDto(
|
||||
id = tempMessageId,
|
||||
conversationId = id,
|
||||
@@ -736,24 +695,18 @@ class CreateMemoryViewModel(
|
||||
)
|
||||
historyMessages.value = historyMessages.value + tempMessage
|
||||
|
||||
// 保存音频文件路径和时长
|
||||
_audioFilePaths.value = _audioFilePaths.value + (tempMessageId to filePath)
|
||||
_audioDurations.value = _audioDurations.value + (tempMessageId to durationSeconds)
|
||||
|
||||
try {
|
||||
// 显示加载动画
|
||||
isTyping.value = true
|
||||
|
||||
// 发送音频消息
|
||||
webSocketClient.sendAudioMessage(audioBytes, id, durationSeconds)
|
||||
conversationRealtime.sendAudioMessage(audioBytes, id, durationSeconds)
|
||||
Log.d(TAG, "音频消息发送成功")
|
||||
|
||||
} catch (e: Exception) {
|
||||
isTyping.value = false
|
||||
Log.e(TAG, "音频消息发送失败: ${e.message}", e)
|
||||
connectionStatus.value = "发送失败: ${e.message}"
|
||||
errorMessages.value = (errorMessages.value + "发送失败: ${e.message}").takeLast(10)
|
||||
// 移除临时消息
|
||||
historyMessages.value = historyMessages.value.filter { it.id != tempMessageId }
|
||||
_audioFilePaths.value = _audioFilePaths.value - tempMessageId
|
||||
_audioDurations.value = _audioDurations.value - tempMessageId
|
||||
@@ -763,17 +716,11 @@ class CreateMemoryViewModel(
|
||||
|
||||
// ==================== 音频播放功能 ====================
|
||||
|
||||
/**
|
||||
* 播放/暂停音频
|
||||
*/
|
||||
fun toggleAudioPlayback(messageId: String, filePath: String) {
|
||||
Log.d(TAG, "切换音频播放状态: $messageId, $filePath")
|
||||
audioPlayer.play(messageId, filePath)
|
||||
}
|
||||
|
||||
/**
|
||||
* 停止音频播放
|
||||
*/
|
||||
fun stopAudioPlayback() {
|
||||
audioPlayer.stop()
|
||||
}
|
||||
@@ -784,22 +731,19 @@ class CreateMemoryViewModel(
|
||||
Log.d(TAG, "准备发送文本消息: $text")
|
||||
|
||||
viewModelScope.launch {
|
||||
// 确保已连接
|
||||
if (conversationId.value == null) {
|
||||
Log.d(TAG, "对话ID为空,开始创建新对话")
|
||||
startConversation()
|
||||
// 等待连接建立
|
||||
delay(500)
|
||||
}
|
||||
|
||||
// 检查连接状态
|
||||
if (!webSocketClient.isConnected()) {
|
||||
if (!conversationRealtime.isConnected()) {
|
||||
Log.w(TAG, "WebSocket未连接,尝试重连")
|
||||
connectionStatus.value = "未连接,正在重连..."
|
||||
conversationId.value?.let { id ->
|
||||
val token = TokenManager.getAccessToken()
|
||||
try {
|
||||
webSocketClient.connect(
|
||||
conversationRealtime.connect(
|
||||
id,
|
||||
token,
|
||||
onMessage = { message ->
|
||||
@@ -811,7 +755,7 @@ class CreateMemoryViewModel(
|
||||
connectionStatus.value = "错误: $errorMsg"
|
||||
}
|
||||
)
|
||||
delay(500) // 等待连接建立
|
||||
delay(500)
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "重连失败: ${e.message}", e)
|
||||
connectionStatus.value = "重连失败: ${e.message}"
|
||||
@@ -826,10 +770,8 @@ class CreateMemoryViewModel(
|
||||
|
||||
conversationId.value?.let { id ->
|
||||
Log.d(TAG, "发送消息到对话: $id")
|
||||
// 添加到用户消息列表
|
||||
userMessages.value = userMessages.value + text
|
||||
|
||||
// 添加到历史消息(临时,等待服务器确认)
|
||||
val tempMessage = MessageDto(
|
||||
id = "temp_user_${System.currentTimeMillis()}",
|
||||
conversationId = id,
|
||||
@@ -840,19 +782,15 @@ class CreateMemoryViewModel(
|
||||
)
|
||||
historyMessages.value = historyMessages.value + tempMessage
|
||||
|
||||
// 发送文本消息
|
||||
try {
|
||||
// 立即显示加载动画
|
||||
isTyping.value = true
|
||||
webSocketClient.sendTextMessage(text, id)
|
||||
conversationRealtime.sendText(id, text)
|
||||
Log.d(TAG, "消息发送成功")
|
||||
} catch (e: Exception) {
|
||||
// 发送失败时隐藏加载动画
|
||||
isTyping.value = false
|
||||
Log.e(TAG, "消息发送失败: ${e.message}", e)
|
||||
connectionStatus.value = "发送失败: ${e.message}"
|
||||
errorMessages.value = (errorMessages.value + "发送失败: ${e.message}").takeLast(10)
|
||||
// 移除临时消息
|
||||
historyMessages.value = historyMessages.value.filter { it.id != tempMessage.id }
|
||||
}
|
||||
} ?: run {
|
||||
@@ -863,7 +801,6 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
|
||||
private fun handleWebSocketMessage(message: WebSocketMessage) {
|
||||
// 更新调试信息
|
||||
lastMessageType.value = message.type.name
|
||||
lastMessageTime.value = java.text.SimpleDateFormat("HH:mm:ss", java.util.Locale.getDefault()).format(java.util.Date())
|
||||
messageCount.value = messageCount.value + 1
|
||||
@@ -916,7 +853,6 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
MessageType.agent_response -> {
|
||||
// 处理Agent回复(可能有多条消息,每条作为单独气泡显示)
|
||||
val text = message.getString("text") ?: ""
|
||||
val isTransition = message.getBoolean("transition") == true
|
||||
val index = message.getInt("index") ?: 0
|
||||
@@ -927,12 +863,10 @@ class CreateMemoryViewModel(
|
||||
return
|
||||
}
|
||||
|
||||
// 收到第一条回复时,隐藏打字指示器
|
||||
if (index == 0) {
|
||||
isTyping.value = false
|
||||
}
|
||||
|
||||
// 每条消息立即作为单独的气泡添加到历史消息
|
||||
conversationId.value?.let { id ->
|
||||
val aiMessage = MessageDto(
|
||||
id = "ai_${System.currentTimeMillis()}_$index",
|
||||
@@ -945,42 +879,35 @@ class CreateMemoryViewModel(
|
||||
historyMessages.value = historyMessages.value + aiMessage
|
||||
}
|
||||
|
||||
// 更新 agentResponse(用于显示最新回复)
|
||||
if (index == 0) {
|
||||
agentResponse.value = text
|
||||
} else {
|
||||
agentResponse.value += "\n\n$text"
|
||||
}
|
||||
|
||||
// 如果是最后一条消息,结束流式状态
|
||||
if (index >= total - 1) {
|
||||
isStreaming.value = false
|
||||
streamingText.value = ""
|
||||
isTyping.value = false
|
||||
webSocketClient.setGenerating(false)
|
||||
conversationRealtime.setGenerating(false)
|
||||
}
|
||||
}
|
||||
MessageType.agent_response_start -> {
|
||||
// 流式回复开始
|
||||
isStreaming.value = true
|
||||
isTyping.value = false // 流式开始时隐藏打字指示器
|
||||
isTyping.value = false
|
||||
streamingText.value = ""
|
||||
webSocketClient.setGenerating(true)
|
||||
conversationRealtime.setGenerating(true)
|
||||
}
|
||||
MessageType.agent_response_chunk -> {
|
||||
// 流式回复片段
|
||||
val chunk = message.getString("text") ?: ""
|
||||
streamingText.value += chunk
|
||||
// 收到第一个片段时,隐藏打字指示器
|
||||
if (streamingText.value == chunk) {
|
||||
isTyping.value = false
|
||||
}
|
||||
}
|
||||
MessageType.agent_response_end -> {
|
||||
// 流式回复结束
|
||||
agentResponse.value = streamingText.value
|
||||
|
||||
// 添加到历史消息
|
||||
conversationId.value?.let { id ->
|
||||
val aiMessage = MessageDto(
|
||||
id = "ai_${System.currentTimeMillis()}",
|
||||
@@ -995,15 +922,12 @@ class CreateMemoryViewModel(
|
||||
|
||||
isStreaming.value = false
|
||||
streamingText.value = ""
|
||||
webSocketClient.setGenerating(false)
|
||||
conversationRealtime.setGenerating(false)
|
||||
}
|
||||
MessageType.agent_typing -> {
|
||||
// AI正在输入
|
||||
isTyping.value = true
|
||||
}
|
||||
MessageType.text -> {
|
||||
// 处理文本消息响应(如果需要)
|
||||
}
|
||||
MessageType.text -> { }
|
||||
MessageType.connect -> {
|
||||
Log.d(TAG, "收到连接确认消息,设置状态为已连接")
|
||||
connectionStatus.value = "已连接"
|
||||
@@ -1016,8 +940,7 @@ class CreateMemoryViewModel(
|
||||
connectionStatus.value = "对话已结束"
|
||||
isRecording.value = false
|
||||
isStreaming.value = false
|
||||
webSocketClient.setGenerating(false)
|
||||
// 触发对话结束后的处理
|
||||
conversationRealtime.setGenerating(false)
|
||||
handleConversationEnded()
|
||||
}
|
||||
MessageType.error -> {
|
||||
@@ -1027,19 +950,16 @@ class CreateMemoryViewModel(
|
||||
wsIsConnected.value = false
|
||||
isStreaming.value = false
|
||||
isTyping.value = false
|
||||
webSocketClient.setGenerating(false)
|
||||
conversationRealtime.setGenerating(false)
|
||||
}
|
||||
else -> {}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 取消当前正在生成的回复
|
||||
*/
|
||||
fun cancelGeneration() {
|
||||
viewModelScope.launch {
|
||||
conversationId.value?.let { id ->
|
||||
webSocketClient.cancelGeneration(id)
|
||||
conversationRealtime.cancelGeneration(id)
|
||||
isStreaming.value = false
|
||||
streamingText.value = ""
|
||||
isTyping.value = false
|
||||
@@ -1047,26 +967,20 @@ class CreateMemoryViewModel(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理对话结束后的逻辑
|
||||
* 等待后台任务完成,然后刷新章节列表
|
||||
*/
|
||||
private fun handleConversationEnded() {
|
||||
viewModelScope.launch {
|
||||
isProcessing.value = true
|
||||
processingStatus.value = "正在处理回忆录..."
|
||||
|
||||
// 等待后台处理完成(最多等待3分钟)
|
||||
val maxWaitSeconds = 180
|
||||
val checkInterval = 5 // 每5秒检查一次
|
||||
val checkInterval = 5
|
||||
var elapsed = 0
|
||||
|
||||
while (elapsed < maxWaitSeconds) {
|
||||
delay(checkInterval * 1000L)
|
||||
elapsed += checkInterval
|
||||
|
||||
// 检查任务状态
|
||||
val tasksStatus = apiService.getTasksStatus()
|
||||
val tasksStatus = conversationApi.getTasksStatus()
|
||||
tasksStatus.fold(
|
||||
onSuccess = { status ->
|
||||
val total = status.total
|
||||
@@ -1078,7 +992,6 @@ class CreateMemoryViewModel(
|
||||
|
||||
processingStatus.value = "处理中: 总任务$total, 等待$pending, 运行$running, 成功$success, 失败$failure"
|
||||
|
||||
// 如果所有任务完成,刷新章节列表
|
||||
if (total > 0 && allCompleted) {
|
||||
processingStatus.value = "处理完成!"
|
||||
refreshChapters()
|
||||
@@ -1086,10 +999,8 @@ class CreateMemoryViewModel(
|
||||
return@launch
|
||||
}
|
||||
|
||||
// 如果没有任务但有章节内容,也认为完成
|
||||
if (total == 0 && elapsed > 30) {
|
||||
// 检查是否有章节
|
||||
val chaptersResult = apiService.getChapters()
|
||||
val chaptersResult = conversationApi.getChapters()
|
||||
chaptersResult.fold(
|
||||
onSuccess = { chapters ->
|
||||
if (chapters.isNotEmpty()) {
|
||||
@@ -1109,21 +1020,16 @@ class CreateMemoryViewModel(
|
||||
)
|
||||
}
|
||||
|
||||
// 超时后也刷新章节列表
|
||||
processingStatus.value = "处理超时,正在刷新章节..."
|
||||
refreshChapters()
|
||||
isProcessing.value = false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新章节列表
|
||||
*/
|
||||
private suspend fun refreshChapters() {
|
||||
val chaptersResult = apiService.getChapters()
|
||||
val chaptersResult = conversationApi.getChapters()
|
||||
chaptersResult.fold(
|
||||
onSuccess = { chapterDtos ->
|
||||
// 转换为本地Chapter实体并保存
|
||||
val chapters = chapterDtos.map { dto ->
|
||||
Chapter(
|
||||
id = dto.id,
|
||||
@@ -1154,7 +1060,7 @@ class CreateMemoryViewModel(
|
||||
super.onCleared()
|
||||
pendingSegmentRetryJob?.cancel()
|
||||
viewModelScope.launch {
|
||||
webSocketClient.disconnect()
|
||||
conversationRealtime.disconnect()
|
||||
}
|
||||
recordingCoordinator.release()
|
||||
audioPlayer.release()
|
||||
|
||||
@@ -3,83 +3,39 @@ package com.huaga.life_echo.ui.viewmodel
|
||||
import android.content.Context
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.ViewModelProvider
|
||||
import com.huaga.life_echo.config.AppConfig
|
||||
import com.huaga.life_echo.data.auth.TokenManager
|
||||
import com.huaga.life_echo.data.database.AppDatabase
|
||||
import com.huaga.life_echo.data.repository.*
|
||||
import com.huaga.life_echo.network.ApiService
|
||||
import com.huaga.life_echo.network.AuthService
|
||||
import com.huaga.life_echo.payment.PaymentManager
|
||||
import com.huaga.life_echo.app.AppContainer
|
||||
import com.huaga.life_echo.app.LifeEchoApp
|
||||
import com.huaga.life_echo.feature.conversation.adapters.ConversationApiAdapter
|
||||
import com.huaga.life_echo.feature.conversation.adapters.ConversationRealtimeAdapter
|
||||
import com.huaga.life_echo.feature.memoir.adapters.MemoirApiAdapter
|
||||
|
||||
class ViewModelFactory(private val context: Context) : ViewModelProvider.Factory {
|
||||
|
||||
init {
|
||||
// 初始化TokenManager
|
||||
TokenManager.initialize(context)
|
||||
}
|
||||
|
||||
private val database by lazy { AppDatabase.getDatabase(context) }
|
||||
private val authService by lazy { AuthService() }
|
||||
private val apiService by lazy {
|
||||
ApiService(
|
||||
tokenManager = TokenManager,
|
||||
authService = authService
|
||||
)
|
||||
}
|
||||
private val conversationRepository by lazy {
|
||||
ConversationRepository(
|
||||
conversationDao = database.conversationDao(),
|
||||
segmentDao = database.conversationSegmentDao(),
|
||||
apiService = apiService
|
||||
)
|
||||
}
|
||||
private val chapterRepository by lazy {
|
||||
ChapterRepository(
|
||||
chapterDao = database.chapterDao()
|
||||
)
|
||||
}
|
||||
private val messageRepository by lazy {
|
||||
MessageRepository(
|
||||
messageDao = database.messageDao(),
|
||||
apiService = apiService
|
||||
)
|
||||
}
|
||||
|
||||
private val paymentRepository by lazy {
|
||||
PaymentRepository(apiService = apiService)
|
||||
}
|
||||
|
||||
private val profileRepository by lazy {
|
||||
ProfileRepository(apiService = apiService)
|
||||
}
|
||||
|
||||
private val paymentManager by lazy {
|
||||
PaymentManager(
|
||||
context = context.applicationContext,
|
||||
wechatAppId = AppConfig.WECHAT_APP_ID
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
private val container: AppContainer
|
||||
get() = (context.applicationContext as LifeEchoApp).container
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <T : ViewModel> create(modelClass: Class<T>): T {
|
||||
return when {
|
||||
modelClass.isAssignableFrom(CreateMemoryViewModel::class.java) -> {
|
||||
CreateMemoryViewModel(
|
||||
conversationRepository = conversationRepository,
|
||||
chapterRepository = chapterRepository,
|
||||
messageRepository = messageRepository,
|
||||
context = context
|
||||
conversationRepository = container.conversationRepository,
|
||||
chapterRepository = container.chapterRepository,
|
||||
messageRepository = container.messageRepository,
|
||||
context = context,
|
||||
conversationApi = ConversationApiAdapter(container.apiService),
|
||||
conversationRealtime = ConversationRealtimeAdapter(container.webSocketClient),
|
||||
) as T
|
||||
}
|
||||
modelClass.isAssignableFrom(ConversationListViewModel::class.java) -> {
|
||||
ConversationListViewModel(
|
||||
conversationRepository = conversationRepository
|
||||
conversationRepository = container.conversationRepository
|
||||
) as T
|
||||
}
|
||||
modelClass.isAssignableFrom(MyMemoirViewModel::class.java) -> {
|
||||
MyMemoirViewModel(
|
||||
chapterRepository = chapterRepository,
|
||||
apiService = apiService
|
||||
chapterRepository = container.chapterRepository,
|
||||
memoirApi = container.memoirApi
|
||||
) as T
|
||||
}
|
||||
modelClass.isAssignableFrom(AuthViewModel::class.java) -> {
|
||||
@@ -87,17 +43,16 @@ class ViewModelFactory(private val context: Context) : ViewModelProvider.Factory
|
||||
}
|
||||
modelClass.isAssignableFrom(PaymentViewModel::class.java) -> {
|
||||
PaymentViewModel(
|
||||
paymentRepository = paymentRepository,
|
||||
paymentManager = paymentManager
|
||||
paymentRepository = container.paymentRepository,
|
||||
paymentManager = container.paymentManager
|
||||
) as T
|
||||
}
|
||||
modelClass.isAssignableFrom(ProfileViewModel::class.java) -> {
|
||||
ProfileViewModel(
|
||||
profileRepository = profileRepository
|
||||
profileRepository = container.profileRepository
|
||||
) as T
|
||||
}
|
||||
else -> throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,14 +12,23 @@ import com.huaga.life_echo.data.database.MessageDao
|
||||
import com.huaga.life_echo.data.repository.ChapterRepository
|
||||
import com.huaga.life_echo.data.repository.ConversationRepository
|
||||
import com.huaga.life_echo.data.repository.MessageRepository
|
||||
import com.huaga.life_echo.feature.conversation.ports.AudioSegmentRequest
|
||||
import com.huaga.life_echo.feature.conversation.ports.ConversationApiPort
|
||||
import com.huaga.life_echo.feature.conversation.ports.ConversationRealtimePort
|
||||
import com.huaga.life_echo.feature.voice.RecorderEngine
|
||||
import com.huaga.life_echo.feature.voice.RecorderStartResult
|
||||
import com.huaga.life_echo.feature.voice.RecorderStopResult
|
||||
import com.huaga.life_echo.feature.voice.RecordingCoordinator
|
||||
import com.huaga.life_echo.network.ApiService
|
||||
import com.huaga.life_echo.network.WebSocketMessage
|
||||
import com.huaga.life_echo.network.models.ChapterDto
|
||||
import com.huaga.life_echo.network.models.CreateConversationResponse
|
||||
import com.huaga.life_echo.network.models.MessageDto
|
||||
import com.huaga.life_echo.network.models.TasksStatusDto
|
||||
import com.huaga.life_echo.testutil.MainDispatcherRule
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.test.advanceUntilIdle
|
||||
import kotlinx.coroutines.test.runTest
|
||||
@@ -141,6 +150,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("DEPRECATION")
|
||||
private fun newViewModel(
|
||||
context: Context,
|
||||
recordingCoordinator: RecordingCoordinator,
|
||||
@@ -160,6 +170,8 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
|
||||
apiService = apiService,
|
||||
),
|
||||
context = context,
|
||||
conversationApi = NoOpConversationApiPort(),
|
||||
conversationRealtime = NoOpConversationRealtimePort(),
|
||||
recordingCoordinator = recordingCoordinator,
|
||||
tokenInitializer = {},
|
||||
)
|
||||
@@ -234,63 +246,77 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
|
||||
|
||||
private class FakeConversationDao : ConversationDao {
|
||||
override fun getAllConversations() = flowOf(emptyList<Conversation>())
|
||||
|
||||
override suspend fun getConversationById(id: String): Conversation? = null
|
||||
|
||||
override suspend fun getLatestEmptyConversation(): Conversation? = null
|
||||
|
||||
override suspend fun insertConversation(conversation: Conversation) = Unit
|
||||
|
||||
override suspend fun updateConversation(conversation: Conversation) = Unit
|
||||
|
||||
override suspend fun deleteConversation(conversation: Conversation) = Unit
|
||||
|
||||
override suspend fun deleteOtherEmptyConversations(keepId: String) = Unit
|
||||
}
|
||||
|
||||
private class FakeConversationSegmentDao : ConversationSegmentDao {
|
||||
override fun getSegmentsByConversationId(conversationId: String) =
|
||||
flowOf(emptyList<ConversationSegment>())
|
||||
|
||||
override suspend fun insertSegment(segment: ConversationSegment) = Unit
|
||||
|
||||
override suspend fun insertSegments(segments: List<ConversationSegment>) = Unit
|
||||
|
||||
override suspend fun updateSegment(segment: ConversationSegment) = Unit
|
||||
|
||||
override suspend fun deleteSegment(segment: ConversationSegment) = Unit
|
||||
}
|
||||
|
||||
private class FakeChapterDao : ChapterDao {
|
||||
override fun getAllChapters() = flowOf(emptyList<Chapter>())
|
||||
|
||||
override suspend fun getChapterById(id: String): Chapter? = null
|
||||
|
||||
override suspend fun getChaptersByCategory(category: String): List<Chapter> = emptyList()
|
||||
|
||||
override suspend fun insertChapter(chapter: Chapter) = Unit
|
||||
|
||||
override suspend fun insertChapters(chapters: List<Chapter>) = Unit
|
||||
|
||||
override suspend fun updateChapter(chapter: Chapter) = Unit
|
||||
|
||||
override suspend fun deleteChapter(chapter: Chapter) = Unit
|
||||
}
|
||||
|
||||
private class FakeMessageDao : MessageDao {
|
||||
override fun getMessagesByConversationId(conversationId: String) =
|
||||
flowOf(emptyList<Message>())
|
||||
|
||||
override suspend fun getMessageById(id: String): Message? = null
|
||||
|
||||
override suspend fun insertMessage(message: Message) = Unit
|
||||
|
||||
override suspend fun insertMessages(messages: List<Message>) = Unit
|
||||
|
||||
override suspend fun updateMessage(message: Message) = Unit
|
||||
|
||||
override suspend fun deleteMessage(message: Message) = Unit
|
||||
|
||||
override suspend fun deleteMessagesByConversationId(conversationId: String) = Unit
|
||||
}
|
||||
|
||||
private class NoOpConversationApiPort : ConversationApiPort {
|
||||
override suspend fun createConversation(): Result<CreateConversationResponse> =
|
||||
Result.failure(Exception("no-op"))
|
||||
override suspend fun getMessages(conversationId: String): Result<List<MessageDto>> =
|
||||
Result.success(emptyList())
|
||||
override suspend fun getTasksStatus(): Result<TasksStatusDto> =
|
||||
Result.failure(Exception("no-op"))
|
||||
override suspend fun getChapters(): Result<List<ChapterDto>> =
|
||||
Result.success(emptyList())
|
||||
override suspend fun clearTasks(): Result<Unit> =
|
||||
Result.success(Unit)
|
||||
}
|
||||
|
||||
private class NoOpConversationRealtimePort : ConversationRealtimePort {
|
||||
override val state: StateFlow<ConversationRealtimePort.State> =
|
||||
MutableStateFlow(ConversationRealtimePort.State.NotConnected)
|
||||
override suspend fun prepare() = Unit
|
||||
override suspend fun connect(
|
||||
conversationId: String,
|
||||
token: String?,
|
||||
onMessage: (WebSocketMessage) -> Unit,
|
||||
onError: ((String) -> Unit)?,
|
||||
) = Unit
|
||||
override suspend fun disconnect() = Unit
|
||||
override fun isConnected(): Boolean = false
|
||||
override suspend fun sendText(conversationId: String, text: String) = Unit
|
||||
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
|
||||
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit
|
||||
override suspend fun sendTranscribeOnly(audioBytes: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendEndConversation(conversationId: String) = Unit
|
||||
override suspend fun cancelGeneration(conversationId: String) = Unit
|
||||
override fun isGenerating(): Boolean = false
|
||||
override fun setGenerating(generating: Boolean) = Unit
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
package com.huaga.life_echo.ui.viewmodel
|
||||
|
||||
import android.content.Context
|
||||
import com.huaga.life_echo.data.database.Chapter
|
||||
import com.huaga.life_echo.data.database.ChapterDao
|
||||
import com.huaga.life_echo.data.database.Conversation
|
||||
import com.huaga.life_echo.data.database.ConversationDao
|
||||
import com.huaga.life_echo.data.database.ConversationSegment
|
||||
import com.huaga.life_echo.data.database.ConversationSegmentDao
|
||||
import com.huaga.life_echo.data.database.Message
|
||||
import com.huaga.life_echo.data.database.MessageDao
|
||||
import com.huaga.life_echo.data.repository.ChapterRepository
|
||||
import com.huaga.life_echo.data.repository.ConversationRepository
|
||||
import com.huaga.life_echo.data.repository.MessageRepository
|
||||
import com.huaga.life_echo.feature.conversation.ports.AudioSegmentRequest
|
||||
import com.huaga.life_echo.feature.conversation.ports.ConversationApiPort
|
||||
import com.huaga.life_echo.feature.conversation.ports.ConversationRealtimePort
|
||||
import com.huaga.life_echo.network.ApiService
|
||||
import com.huaga.life_echo.network.WebSocketMessage
|
||||
import com.huaga.life_echo.network.models.ChapterDto
|
||||
import com.huaga.life_echo.network.models.CreateConversationResponse
|
||||
import com.huaga.life_echo.network.models.MessageDto
|
||||
import com.huaga.life_echo.network.models.TasksStatusDto
|
||||
import com.huaga.life_echo.testutil.MainDispatcherRule
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import kotlinx.coroutines.test.advanceUntilIdle
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import org.mockito.Mockito
|
||||
import java.io.File
|
||||
import java.nio.file.Files
|
||||
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
class CreateMemoryViewModelWarmupTest {
|
||||
|
||||
@get:Rule
|
||||
val mainDispatcherRule = MainDispatcherRule()
|
||||
|
||||
@Test
|
||||
fun initialize_conversation_prepares_realtime_before_connect() =
|
||||
runTest(mainDispatcherRule.dispatcher.scheduler) {
|
||||
val rootDir = Files.createTempDirectory("warmup-test").toFile()
|
||||
val context = newContext(rootDir)
|
||||
val realtime = FakeConversationRealtimePort()
|
||||
val viewModel = newViewModel(context, realtime = realtime)
|
||||
|
||||
try {
|
||||
viewModel.initializeConversation("conversation-1")
|
||||
advanceUntilIdle()
|
||||
|
||||
assertEquals(listOf("prepare", "connect:conversation-1"), realtime.calls)
|
||||
} finally {
|
||||
rootDir.deleteRecursively()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun initialize_conversation_with_null_does_not_prepare() =
|
||||
runTest(mainDispatcherRule.dispatcher.scheduler) {
|
||||
val rootDir = Files.createTempDirectory("warmup-test").toFile()
|
||||
val context = newContext(rootDir)
|
||||
val realtime = FakeConversationRealtimePort()
|
||||
val viewModel = newViewModel(context, realtime = realtime)
|
||||
|
||||
try {
|
||||
viewModel.initializeConversation(null)
|
||||
advanceUntilIdle()
|
||||
|
||||
assertEquals(emptyList<String>(), realtime.calls)
|
||||
} finally {
|
||||
rootDir.deleteRecursively()
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("DEPRECATION")
|
||||
private fun newViewModel(
|
||||
context: Context,
|
||||
realtime: ConversationRealtimePort = FakeConversationRealtimePort(),
|
||||
): CreateMemoryViewModel {
|
||||
val apiService = ApiService()
|
||||
return CreateMemoryViewModel(
|
||||
conversationRepository = ConversationRepository(
|
||||
conversationDao = FakeConversationDao(),
|
||||
segmentDao = FakeConversationSegmentDao(),
|
||||
apiService = apiService,
|
||||
),
|
||||
chapterRepository = ChapterRepository(
|
||||
chapterDao = FakeChapterDao(),
|
||||
),
|
||||
messageRepository = MessageRepository(
|
||||
messageDao = FakeMessageDao(),
|
||||
apiService = apiService,
|
||||
),
|
||||
context = context,
|
||||
conversationApi = FakeConversationApiPort(),
|
||||
conversationRealtime = realtime,
|
||||
tokenInitializer = {},
|
||||
)
|
||||
}
|
||||
|
||||
private fun newContext(rootDir: File): Context {
|
||||
val filesDir = File(rootDir, "files").apply { mkdirs() }
|
||||
val cacheDir = File(rootDir, "cache").apply { mkdirs() }
|
||||
val context = Mockito.mock(Context::class.java)
|
||||
Mockito.`when`(context.filesDir).thenReturn(filesDir)
|
||||
Mockito.`when`(context.cacheDir).thenReturn(cacheDir)
|
||||
Mockito.`when`(context.applicationContext).thenReturn(context)
|
||||
return context
|
||||
}
|
||||
|
||||
private class FakeConversationRealtimePort : ConversationRealtimePort {
|
||||
val calls = mutableListOf<String>()
|
||||
override val state: StateFlow<ConversationRealtimePort.State> =
|
||||
MutableStateFlow(ConversationRealtimePort.State.NotConnected)
|
||||
|
||||
override suspend fun prepare() {
|
||||
calls += "prepare"
|
||||
}
|
||||
|
||||
override suspend fun connect(
|
||||
conversationId: String,
|
||||
token: String?,
|
||||
onMessage: (WebSocketMessage) -> Unit,
|
||||
onError: ((String) -> Unit)?,
|
||||
) {
|
||||
calls += "connect:$conversationId"
|
||||
}
|
||||
|
||||
override suspend fun disconnect() {
|
||||
calls += "disconnect"
|
||||
}
|
||||
|
||||
override fun isConnected(): Boolean = false
|
||||
override suspend fun sendText(conversationId: String, text: String) = Unit
|
||||
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
|
||||
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit
|
||||
override suspend fun sendTranscribeOnly(audioBytes: ByteArray, conversationId: String) = Unit
|
||||
override suspend fun sendEndConversation(conversationId: String) = Unit
|
||||
override suspend fun cancelGeneration(conversationId: String) = Unit
|
||||
override fun isGenerating(): Boolean = false
|
||||
override fun setGenerating(generating: Boolean) = Unit
|
||||
}
|
||||
|
||||
private class FakeConversationApiPort : ConversationApiPort {
|
||||
override suspend fun createConversation(): Result<CreateConversationResponse> =
|
||||
Result.failure(Exception("no-op"))
|
||||
override suspend fun getMessages(conversationId: String): Result<List<MessageDto>> =
|
||||
Result.failure(Exception("test-offline"))
|
||||
override suspend fun getTasksStatus(): Result<TasksStatusDto> =
|
||||
Result.failure(Exception("no-op"))
|
||||
override suspend fun getChapters(): Result<List<ChapterDto>> =
|
||||
Result.success(emptyList())
|
||||
override suspend fun clearTasks(): Result<Unit> = Result.success(Unit)
|
||||
}
|
||||
|
||||
private class FakeConversationDao : ConversationDao {
|
||||
override fun getAllConversations() = flowOf(emptyList<Conversation>())
|
||||
override suspend fun getConversationById(id: String): Conversation? = null
|
||||
override suspend fun getLatestEmptyConversation(): Conversation? = null
|
||||
override suspend fun insertConversation(conversation: Conversation) = Unit
|
||||
override suspend fun updateConversation(conversation: Conversation) = Unit
|
||||
override suspend fun deleteConversation(conversation: Conversation) = Unit
|
||||
override suspend fun deleteOtherEmptyConversations(keepId: String) = Unit
|
||||
}
|
||||
|
||||
private class FakeConversationSegmentDao : ConversationSegmentDao {
|
||||
override fun getSegmentsByConversationId(conversationId: String) =
|
||||
flowOf(emptyList<ConversationSegment>())
|
||||
override suspend fun insertSegment(segment: ConversationSegment) = Unit
|
||||
override suspend fun insertSegments(segments: List<ConversationSegment>) = Unit
|
||||
override suspend fun updateSegment(segment: ConversationSegment) = Unit
|
||||
override suspend fun deleteSegment(segment: ConversationSegment) = Unit
|
||||
}
|
||||
|
||||
private class FakeChapterDao : ChapterDao {
|
||||
override fun getAllChapters() = flowOf(emptyList<Chapter>())
|
||||
override suspend fun getChapterById(id: String): Chapter? = null
|
||||
override suspend fun getChaptersByCategory(category: String): List<Chapter> = emptyList()
|
||||
override suspend fun insertChapter(chapter: Chapter) = Unit
|
||||
override suspend fun insertChapters(chapters: List<Chapter>) = Unit
|
||||
override suspend fun updateChapter(chapter: Chapter) = Unit
|
||||
override suspend fun deleteChapter(chapter: Chapter) = Unit
|
||||
}
|
||||
|
||||
private class FakeMessageDao : MessageDao {
|
||||
override fun getMessagesByConversationId(conversationId: String) =
|
||||
flowOf(emptyList<Message>())
|
||||
override suspend fun getMessageById(id: String): Message? = null
|
||||
override suspend fun insertMessage(message: Message) = Unit
|
||||
override suspend fun insertMessages(messages: List<Message>) = Unit
|
||||
override suspend fun updateMessage(message: Message) = Unit
|
||||
override suspend fun deleteMessage(message: Message) = Unit
|
||||
override suspend fun deleteMessagesByConversationId(conversationId: String) = Unit
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user