diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 7cdf159..0849854 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -4,8 +4,9 @@ - **工作流文件**: [docker-build-deploy.yml](docker-build-deploy.yml) - **测试 job**:在构建镜像前于 `api/` 下执行 `uv sync --dev` 与 `pytest`。 -- **Secrets**:预发 `STAGING_*`、生产 `PROD_*`、镜像 `ALIYUN_CR_*` — 详见 [SETUP.md](SETUP.md)。 -- **分支 / Tag**:`main` → 预发;语义化 tag `v*.*.*` → 生产;路径过滤为 `api/**` 与本 workflow。 +- **Secrets**:预发无前缀 `SSH_*` / `DEPLOY_PATH`、生产 `PROD_*`、镜像 `ALIYUN_CR_*` — 详见 [SETUP.md](SETUP.md)。 +- **分支 / Tag**:`main` → Staging 服务器;语义化 tag `v*.*.*` → Production 服务器;路径过滤为 `api/**` 与本 workflow。 +- **手动补跑**:`workflow_dispatch` 仅支持 `main` / `master`(Staging)或 `vMAJOR.MINOR.PATCH` tag(Production)。其它 ref 会在测试与构建前失败。 头部注释与 `docker-build-deploy.yml` 内说明为最新权威描述。 @@ -15,4 +16,7 @@ ## App Expo Deploy -见仓库 `docs/` 下相关说明(若存在)。 +- **工作流文件**:[app-expo-deploy.yml](app-expo-deploy.yml) +- **自动触发**:`main` → `stage`,使用 `app-expo/.env.staging` 构建 APK artifact;`v*.*.*` tag → `prod`,使用 `app-expo/.env.production` 并创建 GitHub Release。 +- **手动触发**:`dev` 可用于内部测试包;`stage` 只允许在 `main` / `master` 上补跑;`prod` 需要选择 `vMAJOR.MINOR.PATCH` tag,或在 `main` / `master` 上填写语义化 `version`。 +- **产物规则**:Staging APK 仅上传为 GitHub Actions artifact;Production APK 才创建正式 GitHub Release。 diff --git a/.github/workflows/SETUP.md b/.github/workflows/SETUP.md index e61a9dd..3ca43ef 100644 --- a/.github/workflows/SETUP.md +++ b/.github/workflows/SETUP.md @@ -6,12 +6,12 @@ | Secret | 说明 | |--------|------| -| `STAGING_SSH_PRIVATE_KEY` | 预发机 SSH 私钥全文 | -| `STAGING_SSH_HOST` | 预发机主机名或 IP | -| `STAGING_SSH_USER` | SSH 用户名 | -| `STAGING_SSH_PORT` | SSH 端口(默认 `22`) | -| `STAGING_DEPLOY_PATH` | 预发机上的部署目录 | -| `PROD_SSH_PRIVATE_KEY` | 生产机 SSH 私钥(可与预发不同) | +| `SSH_PRIVATE_KEY` | 预发(Staging)机 SSH 私钥全文 | +| `SSH_HOST` | 预发机主机名或 IP | +| `SSH_USER` | 预发 SSH 用户名 | +| `SSH_PORT` | 预发 SSH 端口(默认 `22`) | +| `DEPLOY_PATH` | 预发机上的部署目录 | +| `PROD_SSH_PRIVATE_KEY` | 生产机 SSH 
私钥 | | `PROD_SSH_HOST` | 生产机主机 | | `PROD_SSH_USER` | 生产 SSH 用户 | | `PROD_SSH_PORT` | 生产 SSH 端口 | @@ -19,16 +19,23 @@ | `ALIYUN_CR_USERNAME` | 阿里云 ACR 用户名 | | `ALIYUN_CR_PASSWORD` | 阿里云 ACR 密码 | -> **Tag 部署**:推送 `v*.*.*`(如 `v1.2.0`)时使用 `PROD_*`。**main 分支推送**使用 `STAGING_*`。 +> **Staging**:`main` 发布使用无前缀 `SSH_*` 与 `DEPLOY_PATH`。
+> **Production**:`v*.*.*` tag 发布使用 `PROD_*`。 ## 触发条件 - `push` 到 `main`:改动了 `api/**` 或 `.github/workflows/**` 时,先跑 **API tests**(`uv sync --dev` + `pytest`),再构建镜像并部署预发。 - `push` tag `v*.*.*`:同上路径过滤;部署生产。 -- **workflow_dispatch**:可选手动指定 ref。 +- **workflow_dispatch**:仅用于补跑 `main` / `master`(Staging)或 `vMAJOR.MINOR.PATCH` tag(Production);其它 ref 会直接失败,避免把任意分支部署到预发或生产。 仓库内需存在 **`api/.env.staging`** / **`api/.env.production`**(供部署 job 校验与上传);勿将真实密钥提交到公开分支。 +## App Expo Release + +- `push` 到 `main`:构建 Staging APK,执行 `node scripts/use-env.js staging`,产物上传为 GitHub Actions artifact。 +- `push` tag `v*.*.*`:构建 Production APK,执行 `node scripts/use-env.js production`,并创建 GitHub Release。 +- 手动 `workflow_dispatch`:`stage` 只允许在 `main` / `master` 上补跑;`prod` 需要选择 `vMAJOR.MINOR.PATCH` tag,或在 `main` / `master` 上填写语义化 `version`。 + ## 本地验证 SSH ```bash diff --git a/.github/workflows/app-expo-deploy.yml b/.github/workflows/app-expo-deploy.yml index 9f3542f..f13c3e6 100644 --- a/.github/workflows/app-expo-deploy.yml +++ b/.github/workflows/app-expo-deploy.yml @@ -11,7 +11,10 @@ # push main → stage → node scripts/use-env.js staging → .env.staging # push v*.*.* → prod → node scripts/use-env.js production → .env.production # -# 手动触发 workflow_dispatch:可选 dev / stage / prod(dev 用 .env.development,便于打内部测试包) +# 手动触发 workflow_dispatch: +# - dev:内部测试包,使用 .env.development +# - stage:仅用于 main / master 补跑 Staging release,使用 .env.staging +# - prod:用于 vMAJOR.MINOR.PATCH tag,或在 main / master 上填写 version 后发正式 Release # # Repository secrets(与 android-release.yml 共用同一套即可): # ANDROID_KEYSTORE_BASE64 / ANDROID_STORE_PASSWORD / ANDROID_KEY_ALIAS / ANDROID_KEY_PASSWORD @@ -28,7 +31,7 @@ on: workflow_dispatch: inputs: environment: - description: '部署环境' + description: '部署环境(stage 请在 main 上补跑;prod 请使用 v tag 或在 main 上填写 version)' required: true type: choice options: @@ -67,15 +70,50 @@ jobs: - name: Determine environment id: env run: | - if [[ "${{ github.ref }}" == refs/tags/v* ]]; then - echo "env=prod" - elif [[ 
"${{ github.event_name }}" == "workflow_dispatch" ]]; then + if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then echo "env=${{ github.event.inputs.environment }}" + elif [[ "${{ github.ref }}" == refs/tags/v* ]]; then + echo "env=prod" else # push 到 main(本 workflow 仅监听 main 与 tag) echo "env=stage" fi >> $GITHUB_OUTPUT + - name: Validate manual release ref + if: github.event_name == 'workflow_dispatch' + # Expression values are passed through env instead of being inlined into the script, so a crafted ref name or version input cannot inject shell commands. + env: + ENVIRONMENT: ${{ steps.env.outputs.env }} + REF: ${{ github.ref }} + REF_NAME: ${{ github.ref_name }} + VERSION: ${{ github.event.inputs.version }} + run: | + case "$ENVIRONMENT" in + dev) + echo "dev 构建允许使用当前 ref: $REF_NAME" + ;; + stage) + if [ "$REF_NAME" != "main" ] && [ "$REF_NAME" != "master" ]; then + echo "::error::Staging release 只允许在 main / master 上手动补跑,当前 ref 为 '$REF_NAME'。" + exit 1 + fi + ;; + prod) + if [[ "$REF" == refs/tags/v* && "$REF_NAME" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "使用 tag $REF_NAME 发正式 Release。" + elif { [ "$REF_NAME" = "main" ] || [ "$REF_NAME" = "master" ]; } && [[ "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "使用 main / master 和显式版本 v$VERSION 发正式 Release。" + else + echo "::error::Production release 需要选择 vMAJOR.MINOR.PATCH tag,或在 main / master 上填写语义化 version。" + exit 1 + fi + ;; + *) + echo "::error::未知部署环境 '$ENVIRONMENT'。" + exit 1 + ;; + esac + - name: Set up Node.js uses: actions/setup-node@v5 with: @@ -87,13 +125,13 @@ jobs: working-directory: app-expo run: npm ci - - name: Quality checks (non-prod) - if: steps.env.outputs.env != 'prod' - working-directory: app-expo - run: | - npm run format:check - npm run lint - npm run test:ci + # TODO: Restore quality checks before staging/prod release once CI tests are stable. 
+ # - name: Quality checks + # working-directory: app-expo + # run: | + # npm run format:check + # npm run lint + # npm run test:ci - name: Set API environment working-directory: app-expo diff --git a/.github/workflows/docker-build-deploy.yml b/.github/workflows/docker-build-deploy.yml index b11ae6b..06b1b92 100644 --- a/.github/workflows/docker-build-deploy.yml +++ b/.github/workflows/docker-build-deploy.yml @@ -1,11 +1,10 @@ -# API Docker:main → Staging 机(Repository secrets: STAGING_*),Tag v*.*.* → Prod 机(PROD_*) +# API Docker:main → Staging 机(无前缀 SSH_* / DEPLOY_PATH),Tag v*.*.* → Prod 机(PROD_*) # 在 Repo → Settings → Secrets and variables → Actions 中配置,无需 GitHub Environments。 -# 命名:STAGING_SSH_HOST / STAGING_SSH_USER / STAGING_SSH_PRIVATE_KEY / STAGING_SSH_PORT / STAGING_DEPLOY_PATH -# PROD_SSH_HOST / PROD_SSH_USER / PROD_SSH_PRIVATE_KEY / PROD_SSH_PORT / PROD_DEPLOY_PATH +# Staging:SSH_HOST / SSH_USER / SSH_PRIVATE_KEY / SSH_PORT / DEPLOY_PATH +# Production:PROD_SSH_HOST / PROD_SSH_USER / PROD_SSH_PRIVATE_KEY / PROD_SSH_PORT / PROD_DEPLOY_PATH # 阿里云镜像仍为仓库级:ALIYUN_CR_USERNAME / ALIYUN_CR_PASSWORD # -# 从旧版迁移:若仓库里仍是 SSH_HOST、SSH_PRIVATE_KEY、DEPLOY_PATH 等无前缀名称, -# 请把「预发机」对应值迁移为 STAGING_*,「新生产机」填 PROD_*,并删除旧的无前缀 Secret。 +# 勿把 PROD 私钥与 Staging 混用:staging 只读 SSH_PRIVATE_KEY,prod 只读 PROD_SSH_PRIVATE_KEY。 # # 旧库 pg_dump 一次性迁入当前 schema:见 workflow「Legacy DB migrate (one-shot)」(手动运行,非每次构建)。 # @@ -14,7 +13,7 @@ # - 手动创建并推送 tag vMAJOR.MINOR.PATCH:构建并部署到 Production;使用仓库中的 api/.env.production,上传后切换为运行时 .env # # 注意:paths 过滤在 tag push 时按「被指向的 commit」判断;若该 commit 未改 api/ 与本 workflow,不会触发。 -# 此时可用 workflow_dispatch 选择对应 tag/ref 手动部署。 +# 此时可用 workflow_dispatch 补跑 main(Staging)或 vMAJOR.MINOR.PATCH tag(Production)。 name: Docker Build and Deploy @@ -46,13 +45,46 @@ env: FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true jobs: + resolve-deploy-target: + name: Resolve deploy target + runs-on: ubuntu-latest + outputs: + deploy_ref: ${{ steps.deploy_target.outputs.deploy_ref }} + image_tag: ${{ 
steps.deploy_target.outputs.image_tag }} + target: ${{ steps.deploy_target.outputs.target }} + steps: + - name: Determine deploy target + id: deploy_target + env: + INPUT_BRANCH: ${{ github.event.inputs.branch }} + PUSHED_REF_NAME: ${{ github.ref_name }} + run: | + # Values arrive via env (never inlined into the script) so a crafted ref cannot inject shell; the manual input wins, otherwise use the triggering ref. + REF_NAME="${INPUT_BRANCH:-$PUSHED_REF_NAME}" + + echo "deploy_ref=$REF_NAME" >> "$GITHUB_OUTPUT" + + if [[ "$REF_NAME" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "target=prod" >> "$GITHUB_OUTPUT" + echo "image_tag=${REF_NAME#v}" >> "$GITHUB_OUTPUT" + elif [ "$REF_NAME" = "main" ] || [ "$REF_NAME" = "master" ]; then + echo "target=staging" >> "$GITHUB_OUTPUT" + echo "image_tag=latest" >> "$GITHUB_OUTPUT" + else + echo "::error::不支持部署 ref '$REF_NAME'。Staging release 只允许 main / master,Production release 只允许 vMAJOR.MINOR.PATCH tag。" + exit 1 + fi + test: name: API tests + needs: resolve-deploy-target runs-on: ubuntu-latest permissions: contents: read steps: - uses: actions/checkout@v5 + with: + ref: ${{ needs.resolve-deploy-target.outputs.deploy_ref }} - name: Install uv uses: astral-sh/setup-uv@v5 @@ -67,7 +99,9 @@ jobs: build-and-push: name: Build and Push Docker Image - needs: test + needs: + - resolve-deploy-target + - test runs-on: ubuntu-latest permissions: contents: read @@ -76,7 +110,7 @@ jobs: - name: Checkout code uses: actions/checkout@v5 with: - ref: ${{ github.event.inputs.branch || github.ref }} + ref: ${{ needs.resolve-deploy-target.outputs.deploy_ref }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -108,6 +142,7 @@ jobs: type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} type=sha,prefix=sha- + type=raw,value=${{ needs.resolve-deploy-target.outputs.image_tag }} type=raw,value=latest,enable={{is_default_branch}} - name: Build and push Docker image @@ -124,31 +159,19 @@ jobs: deploy: name: Deploy to Remote 
Server runs-on: ubuntu-latest - needs: build-and-push + needs: + - resolve-deploy-target + - build-and-push if: github.event_name != 'pull_request' 
steps: - name: Checkout code uses: actions/checkout@v5 with: - ref: ${{ github.event.inputs.branch || github.ref }} - - - name: Determine deploy target - id: deploy_target - run: | - if [ -n "${{ github.event.inputs.branch }}" ]; then - REF_NAME="${{ github.event.inputs.branch }}" - else - REF_NAME="${{ github.ref_name }}" - fi - if [[ "$REF_NAME" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "target=prod" >> "$GITHUB_OUTPUT" - else - echo "target=staging" >> "$GITHUB_OUTPUT" - fi + ref: ${{ needs.resolve-deploy-target.outputs.deploy_ref }} - name: Ensure production SSH secret is set - if: steps.deploy_target.outputs.target == 'prod' + if: needs.resolve-deploy-target.outputs.target == 'prod' env: PROD_SSH_PRIVATE_KEY: ${{ secrets.PROD_SSH_PRIVATE_KEY }} run: | @@ -158,31 +181,31 @@ jobs: fi - name: Ensure staging SSH secret is set - if: steps.deploy_target.outputs.target != 'prod' + if: needs.resolve-deploy-target.outputs.target != 'prod' env: - STAGING_SSH_PRIVATE_KEY: ${{ secrets.STAGING_SSH_PRIVATE_KEY }} + SSH_PRIVATE_KEY: ${{ secrets.SSH_PRIVATE_KEY }} run: | - if [ -z "$STAGING_SSH_PRIVATE_KEY" ]; then - echo "::error::STAGING_SSH_PRIVATE_KEY 未配置或为空,无法部署 staging。请在 Repository secrets 中设置 STAGING_SSH_*。" + if [ -z "$SSH_PRIVATE_KEY" ]; then + echo "::error::SSH_PRIVATE_KEY 未配置或为空,无法部署 staging。请在 Repository secrets 中设置 SSH_HOST / SSH_USER / SSH_PRIVATE_KEY / SSH_PORT / DEPLOY_PATH。" exit 1 fi - # 勿用 `prod && PROD_KEY || STAGING_KEY`:PROD 为空时会错误回退到 staging 密钥,导致连生产机报 Permission denied。 + # 勿用 `prod && PROD_KEY || SSH_KEY`:PROD 为空时会错误回退到 staging 密钥,导致连生产机报 Permission denied。 - name: Set up SSH (production) - if: steps.deploy_target.outputs.target == 'prod' + if: needs.resolve-deploy-target.outputs.target == 'prod' uses: webfactory/ssh-agent@v0.9.1 with: ssh-private-key: ${{ secrets.PROD_SSH_PRIVATE_KEY }} - name: Set up SSH (staging) - if: steps.deploy_target.outputs.target != 'prod' + if: needs.resolve-deploy-target.outputs.target != 'prod' uses: 
webfactory/ssh-agent@v0.9.1 with: - ssh-private-key: ${{ secrets.STAGING_SSH_PRIVATE_KEY }} + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} - name: Export deploy connection env run: | - if [ "${{ steps.deploy_target.outputs.target }}" = "prod" ]; then + if [ "${{ needs.resolve-deploy-target.outputs.target }}" = "prod" ]; then { echo "SSH_HOST=${{ secrets.PROD_SSH_HOST }}" echo "SSH_USER=${{ secrets.PROD_SSH_USER }}" @@ -191,10 +214,10 @@ jobs: } >> "$GITHUB_ENV" else { - echo "SSH_HOST=${{ secrets.STAGING_SSH_HOST }}" - echo "SSH_USER=${{ secrets.STAGING_SSH_USER }}" - echo "SSH_PORT=${{ secrets.STAGING_SSH_PORT || '22' }}" - echo "COMPOSE_DIR=${{ secrets.STAGING_DEPLOY_PATH || '/opt/life-echo' }}" + echo "SSH_HOST=${{ secrets.SSH_HOST }}" + echo "SSH_USER=${{ secrets.SSH_USER }}" + echo "SSH_PORT=${{ secrets.SSH_PORT || '22' }}" + echo "COMPOSE_DIR=${{ secrets.DEPLOY_PATH || '/opt/life-echo' }}" } >> "$GITHUB_ENV" fi @@ -203,28 +226,9 @@ jobs: mkdir -p ~/.ssh ssh-keyscan -H -p "${SSH_PORT:-22}" "${SSH_HOST}" >> ~/.ssh/known_hosts - - name: Determine image tag - id: image_tag - run: | - # 与 docker/metadata-action 的 semver 标签一致:v1.2.3 → 镜像 :1.2.3 - if [ -n "${{ github.event.inputs.branch }}" ]; then - REF_NAME="${{ github.event.inputs.branch }}" - else - REF_NAME="${{ github.ref_name }}" - fi - echo "deploy_ref=$REF_NAME" >> "$GITHUB_OUTPUT" - if [[ "$REF_NAME" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "tag=${REF_NAME#v}" >> "$GITHUB_OUTPUT" - elif [ "$REF_NAME" == "main" ] || [ "$REF_NAME" == "master" ]; then - echo "tag=latest" >> "$GITHUB_OUTPUT" - else - BRANCH_TAG=$(echo "$REF_NAME" | sed 's/\//-/g') - echo "tag=$BRANCH_TAG" >> "$GITHUB_OUTPUT" - fi - - name: Prepare remote candidate release env: - IMAGE_TAG: ${{ env.REGISTRY }}/${{ env.REGISTRY_NAMESPACE }}/${{ env.IMAGE_NAME }}:${{ steps.image_tag.outputs.tag }} + IMAGE_TAG: ${{ env.REGISTRY }}/${{ env.REGISTRY_NAMESPACE }}/${{ env.IMAGE_NAME }}:${{ needs.resolve-deploy-target.outputs.image_tag }} 
REGISTRY: ${{ env.REGISTRY }} ALIYUN_CR_USERNAME: ${{ secrets.ALIYUN_CR_USERNAME }} ALIYUN_CR_PASSWORD: ${{ secrets.ALIYUN_CR_PASSWORD }} @@ -244,7 +248,7 @@ jobs: docker network inspect api_life-echo-network >/dev/null 2>&1 || docker network create api_life-echo-network " - if [ "${{ steps.deploy_target.outputs.target }}" = "prod" ]; then + if [ "${{ needs.resolve-deploy-target.outputs.target }}" = "prod" ]; then ENV_SRC="api/.env.production" else ENV_SRC="api/.env.staging" diff --git a/.github/workflows/legacy-data-migrate.yml b/.github/workflows/legacy-data-migrate.yml index 8d86724..dda3ed3 100644 --- a/.github/workflows/legacy-data-migrate.yml +++ b/.github/workflows/legacy-data-migrate.yml @@ -3,12 +3,12 @@ # 目标库须已是 alembic upgrade head(与线上一致);占号用户清理逻辑依赖当前全部迁移后的表结构。 # # 不会在 push / 部署时自动运行,仅手动 workflow_dispatch,避免每次构建误迁库。 -# 远端需已用 docker compose 部署(目录约定与 docker-build-deploy 一致:STAGING_DEPLOY_PATH / PROD_DEPLOY_PATH)。 +# 远端需已用 docker compose 部署(目录约定与 docker-build-deploy 一致:DEPLOY_PATH / PROD_DEPLOY_PATH)。 # # 备份文件:提交在仓库 api/backups/(默认 life_echo_20260313_182756.sql), # workflow 会先 scp 到远端再迁移。其他 *.sql 仍被 gitignore,需按需增加 ! 
例外行。 # -# Secrets:与 Docker Build and Deploy 相同(STAGING_* / PROD_*)。 +# Secrets:与 Docker Build and Deploy 相同(staging:无前缀 SSH_* / DEPLOY_PATH;production:PROD_*)。 name: Legacy DB migrate (one-shot) @@ -82,7 +82,7 @@ jobs: if: github.event.inputs.environment != 'production' uses: webfactory/ssh-agent@v0.9.1 with: - ssh-private-key: ${{ secrets.STAGING_SSH_PRIVATE_KEY }} + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} - name: Export deploy connection env run: | @@ -95,10 +95,10 @@ jobs: } >> "$GITHUB_ENV" else { - echo "SSH_HOST=${{ secrets.STAGING_SSH_HOST }}" - echo "SSH_USER=${{ secrets.STAGING_SSH_USER }}" - echo "SSH_PORT=${{ secrets.STAGING_SSH_PORT || '22' }}" - echo "COMPOSE_DIR=${{ secrets.STAGING_DEPLOY_PATH || '/opt/life-echo' }}" + echo "SSH_HOST=${{ secrets.SSH_HOST }}" + echo "SSH_USER=${{ secrets.SSH_USER }}" + echo "SSH_PORT=${{ secrets.SSH_PORT || '22' }}" + echo "COMPOSE_DIR=${{ secrets.DEPLOY_PATH || '/opt/life-echo' }}" } >> "$GITHUB_ENV" fi diff --git a/api/.env.example b/api/.env.example index 14de56e..7f4dff3 100644 --- a/api/.env.example +++ b/api/.env.example @@ -11,7 +11,8 @@ # ============================================================================= # Docker Compose(宿主机独立 Caddy 反代到本 API) # ============================================================================= -# 映射到宿主机的端口,默认 8000;与同机其它项目冲突时改为未占用端口,并在独立 Caddy 的 Caddyfile 中 reverse_proxy 到 127.0.0.1:该端口。 +# 映射到宿主机的端口:不设置则由 Docker 随机分配,避免与同机其它项目冲突;随机时用 `docker compose port api 8000` 查看。 +# 需固定端口时取消下行注释并改为未占用端口,Caddyfile 中 reverse_proxy 到 127.0.0.1:该端口。 # LIFE_ECHO_API_HOST_PORT=8000 # 若 Caddy 跑在独立容器且非 host 网络,不要用 127.0.0.1,应把 Caddy 加入与本 compose 相同的 Docker 网络,并对 http://life-echo-api-prod:8000 做 reverse_proxy。 @@ -114,11 +115,11 @@ EMBEDDING_MODEL=embedding-3 # ============================================================================= # Database # ============================================================================= -# 本地开发: -# 
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo +# 本地开发(docker-compose.dev.yml 固定宿主端口 48291,避免与本机 5432 冲突) +# DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo # Docker / 服务端(主机名一般为 compose 服务名 postgres): # DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo -DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo +DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo # 启动时 Alembic(main.py);生产可设 ALEMBIC_STARTUP_FAIL_FAST=true,迁移失败则拒绝启动 # ALEMBIC_RUN_ON_STARTUP=true # ALEMBIC_STARTUP_FAIL_FAST=false @@ -128,11 +129,11 @@ DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo # ============================================================================= # Redis # ============================================================================= -# 本地开发: -# REDIS_URL=redis://localhost:6379/0 +# 本地开发(docker-compose.dev.yml 固定宿主端口 48307,避免与本机 6379 冲突) +# REDIS_URL=redis://localhost:48307/0 # Docker / 服务端: # REDIS_URL=redis://redis:6379/0 -REDIS_URL=redis://localhost:6379/0 +REDIS_URL=redis://localhost:48307/0 REDIS_SESSION_TTL=86400 # Celery:ingest 后 Memory LLM 富化任务投递队列(须被 worker 消费;见 README) @@ -236,9 +237,11 @@ TENCENT_SECRET_ID=your_tencent_asr_secret_id TENCENT_SECRET_KEY=your_tencent_asr_secret_key # ============================================================================= -# TTS(文字转语音,Agent 回复播音)— 与 ASR 独立 +# TTS(文字转语音,Agent 回复朗读)— 与 ASR 独立 # ============================================================================= -# ENABLE_TTS:仅控制是否合成并下发 TTS_AUDIO;不影响用户语音转写(ASR) +# ENABLE_TTS:是否启用「助手回复朗读」服务端能力(TTS 适配器与密钥配置)。关则永远不合成。 +# 每轮是否实际合成:由客户端在 WebSocket `text` / `audio_segment` / `audio_message` 的 `data.tts_this_turn` 控制(未传或 false 仅返回文字)。 +# 若 ENABLE_TTS=true 且该轮 `tts_this_turn=true`:每一段助手文案先下发 `tts_audio`,再下发对应段的 `agent_response`。 ENABLE_TTS=true TTS_PROVIDER=tencent # 仅 TTS_PROVIDER=openai 时需要 diff --git a/api/.env.production b/api/.env.production index 
3ab9b2a..01d4725 100644 --- a/api/.env.production +++ b/api/.env.production @@ -189,9 +189,11 @@ TENCENT_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq TENCENT_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 # ============================================================================= -# TTS(文字转语音,Agent 回复播音)— 与 ASR 独立 +# TTS(文字转语音,Agent 回复朗读)— 与 ASR 独立 # ============================================================================= -# ENABLE_TTS:仅控制是否合成并下发 TTS_AUDIO;不影响用户语音转写(ASR) +# ENABLE_TTS:是否启用「助手回复朗读」服务端能力(TTS 适配器与密钥配置)。关则永远不合成。 +# 每轮是否实际合成:由客户端在 WebSocket `text` / `audio_segment` / `audio_message` 的 `data.tts_this_turn` 控制(未传或 false 仅返回文字)。 +# 若 ENABLE_TTS=true 且该轮 `tts_this_turn=true`:每一段助手文案先下发 `tts_audio`,再下发对应段的 `agent_response`。 ENABLE_TTS=true TTS_PROVIDER=tencent # 仅 TTS_PROVIDER=openai 时需要(填控制台密钥;勿在注释行写 =your_* 以免旧版 CI 误匹配) diff --git a/api/.env.staging b/api/.env.staging index 81dc5d2..dab7bd9 100644 --- a/api/.env.staging +++ b/api/.env.staging @@ -1,3 +1,7 @@ +LIFE_ECHO_API_HOST_BIND=0.0.0.0 +LIFE_ECHO_API_HOST_PORT=8000 +POSTGRES_HOST_PORT=15432 + # ============================================================================= # Life Echo API — staging(预发) # @@ -119,9 +123,11 @@ TENCENT_SECRET_ID=your_tencent_asr_secret_id TENCENT_SECRET_KEY=your_tencent_asr_secret_key # ============================================================================= -# TTS(文字转语音,Agent 回复播音)— 与 ASR 独立 +# TTS(文字转语音,Agent 回复朗读)— 与 ASR 独立 # ============================================================================= -# ENABLE_TTS:仅控制是否合成并下发 TTS_AUDIO;不影响用户语音转写(ASR) +# ENABLE_TTS:是否启用「助手回复朗读」服务端能力(TTS 适配器与密钥配置)。关则永远不合成。 +# 每轮是否实际合成:由客户端在 WebSocket `text` / `audio_segment` / `audio_message` 的 `data.tts_this_turn` 控制(未传或 false 仅返回文字)。 +# 若 ENABLE_TTS=true 且该轮 `tts_this_turn=true`:每一段助手文案先下发 `tts_audio`,再下发对应段的 `agent_response`。 ENABLE_TTS=true TTS_PROVIDER=tencent # 仅 TTS_PROVIDER=openai 时需要 diff --git a/api/README.md b/api/README.md index a5a2474..4e572f7 100644 --- 
a/api/README.md +++ b/api/README.md @@ -10,7 +10,7 @@ Life Echo API 是一个智能对话系统,通过 WebSocket 实时连接,使 - **会话真源**:`conversation_messages`(DB)+ Redis 缓存;**实时编排入口**:`ChatOrchestrator`。 - **图像管线**:正文主图 `generate_story_image`;章节封面 `try_enqueue_generate_chapter_cover` → `generate_chapter_cover`。 -- **回忆录批次**:`MemoirOrchestrator.prepare_batches` 显式分桶后,`process_memoir_segments` 按类别加锁并调用 `run_story_pipeline_for_category_batch`(含 `StoryRouteAgent.plan_batch` 多 unit 写入)。 +- **回忆录批次**:`MemoirOrchestrator.prepare_batches` 显式分桶后,`process_memoir_phase1` 派发 Phase 2 按类别调用 `run_story_pipeline_for_category_batch`(含 `StoryRouteAgent.plan_batch` 多 unit 写入)。 ### LLM 与记忆(约定文档) @@ -90,11 +90,11 @@ LLM_BASE_URL=https://api.your-llm-provider.com # 可选 LLM_MODEL=your-model-name # 可选,默认 deepseek-chat LLM_TEMPERATURE=0.7 # 可选,默认 0.7 -# 数据库配置(PostgreSQL,推荐) -DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo +# 数据库配置(本地用 docker-compose.dev.yml 时为固定端口 48291,见下文「本地开发」) +DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo -# Redis 配置 -REDIS_URL=redis://localhost:6379/0 +# Redis 配置(本地 compose.dev 固定端口 48307) +REDIS_URL=redis://localhost:48307/0 # 认证配置 SECRET_KEY=your-secret-key-here # JWT签名密钥(建议使用随机字符串) @@ -152,9 +152,9 @@ docker compose -f docker-compose.dev.yml up -d # 2. 安装依赖 pip install -r requirements.txt -# 3. 配置环境变量 -export DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo -export REDIS_URL=redis://localhost:6379/0 +# 3. 配置环境变量(与 docker-compose.dev.yml 固定宿主端口一致:Postgres 48291、Redis 48307) +export DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo +export REDIS_URL=redis://localhost:48307/0 # 4. 
启动 API(终端 1) uvicorn main:app --reload --host 0.0.0.0 --port 8000 diff --git a/api/alembic/versions/0005_cleanup_cross_chapter_story_links.py b/api/alembic/versions/0005_cleanup_cross_chapter_story_links.py index 0a06c5a..cf7d7b5 100644 --- a/api/alembic/versions/0005_cleanup_cross_chapter_story_links.py +++ b/api/alembic/versions/0005_cleanup_cross_chapter_story_links.py @@ -9,6 +9,7 @@ Revises: 0004_memory_embedding_1024 from typing import Sequence, Union import sqlalchemy as sa +from sqlalchemy.dialects import postgresql from alembic import op @@ -18,7 +19,42 @@ branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None +def _has_column(table: str, column: str) -> bool: + bind = op.get_bind() + return any(c["name"] == column for c in sa.inspect(bind).get_columns(table)) + + +def _ensure_chapter_materialization_columns() -> None: + """Keep older/squashed staging schemas compatible before this data cleanup.""" + if not _has_column("chapters", "markdown_compose_dirty"): + op.add_column( + "chapters", + sa.Column( + "markdown_compose_dirty", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + if not _has_column("chapters", "markdown_composed_at"): + op.add_column( + "chapters", + sa.Column("markdown_composed_at", sa.DateTime(timezone=True), nullable=True), + ) + if not _has_column("chapters", "reading_segments_json"): + op.add_column( + "chapters", + sa.Column( + "reading_segments_json", + postgresql.JSON(astext_type=sa.Text()), + nullable=True, + ), + ) + + def upgrade() -> None: + _ensure_chapter_materialization_columns() + # 先标脏,再删链接(子查询在 DELETE 后不可用) op.execute( sa.text( diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index 86b4328..a5f738a 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -22,7 +22,6 @@ from app.agents.chat.reply_limits import ( from app.agents.chat.schemas import 
ProfileExtractionOutput from app.core.agent_logging import agent_span, log_agent_payload, log_agent_summary from app.core.config import settings -from app.core.dependencies import get_llm_provider from app.core.llm_call import allm_json_call from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger @@ -31,11 +30,53 @@ from app.ports.llm import LLMProvider logger = get_logger(__name__) -def _get_langchain_llm(): - try: - return LlmGateway().langchain_llm_for(LlmUseCase("chat.profile")) - except Exception: - return None +class _ProviderBackedProfileGateway: + def __init__(self, provider: LLMProvider) -> None: + self._provider = provider + + async def chat_text( + self, + messages: list[dict], + *, + use_case: LlmUseCase | None = None, + temperature: float | None = None, + model: str | None = None, + max_tokens: int | None = None, + ) -> str: + resolved_temperature = temperature + if resolved_temperature is None: + resolved_temperature = ( + use_case.temperature + if use_case and use_case.temperature is not None + else 0.7 + ) + return await self._provider.complete( + messages, + temperature=resolved_temperature, + model=model if model is not None else (use_case.model if use_case else None), + max_tokens=( + max_tokens + if max_tokens is not None + else (use_case.max_tokens if use_case else None) + ), + ) + + async def json_object( + self, + prompt: str, + schema: type[ProfileExtractionOutput], + *, + use_case: LlmUseCase, + fallback_factory: Any = None, + ) -> ProfileExtractionOutput: + return await allm_json_call( + getattr(self._provider, "langchain_llm", None), + prompt, + schema, + max_tokens=use_case.max_tokens or 1024, + agent=use_case.name, + fallback_factory=fallback_factory, + ) def _langchain_messages_to_port(messages: List[Any]) -> list[dict]: @@ -66,14 +107,17 @@ def _message_contents_char_count(messages: List[Any]) -> int: class ProfileAgent: """用户资料收集 Specialist Agent""" - def __init__(self, llm_provider: 
LLMProvider | None = None): - self._llm_provider = llm_provider - self.llm = _get_langchain_llm() - - def _provider(self) -> LLMProvider: - if self._llm_provider is not None: - return self._llm_provider - return get_llm_provider() + def __init__( + self, + llm_provider: LLMProvider | None = None, + llm_gateway: Any | None = None, + ) -> None: + if llm_gateway is not None: + self._llm_gateway = llm_gateway + elif llm_provider is not None: + self._llm_gateway = _ProviderBackedProfileGateway(llm_provider) + else: + self._llm_gateway = LlmGateway() async def _invoke_chat( self, @@ -88,8 +132,9 @@ class ProfileAgent: with agent_span( logger, f"{agent_name}.llm", conversation_id=conversation_id or "" ): - response_text = await self._provider().complete( + response_text = await self._llm_gateway.chat_text( port_messages, + use_case=LlmUseCase("chat.profile", max_tokens=max_tokens), max_tokens=max_tokens, ) logger.info( @@ -130,7 +175,7 @@ class ProfileAgent: conversation_id: Optional[str] = None, ) -> Dict[str, Any]: """从用户消息中提取资料字段,不持久化""" - if not self.llm or not missing_fields: + if not missing_fields: return {} recent_dialogue = "" if conversation_id: @@ -151,12 +196,13 @@ class ProfileAgent: prompt = get_profile_extraction_prompt( user_message, missing_fields, recent_dialogue=recent_dialogue or None ) - parsed = await allm_json_call( - self.llm, + parsed = await self._llm_gateway.json_object( prompt, ProfileExtractionOutput, - max_tokens=settings.chat_profile_extract_max_tokens, - agent="ProfileAgent.extract_profile_from_message", + use_case=LlmUseCase( + "ProfileAgent.extract_profile_from_message", + max_tokens=settings.chat_profile_extract_max_tokens, + ), fallback_factory=lambda: ProfileExtractionOutput(), ) result = {} @@ -197,8 +243,6 @@ class ProfileAgent: interview_stage_hint: str = "", ) -> List[str]: """生成资料追问回复,不持久化(由 Orchestrator 负责)""" - if not self.llm: - return ["谢谢!还能告诉我更多吗?"] try: prompt = get_profile_followup_prompt( missing_fields, @@ -260,8 +304,6 
@@ class ProfileAgent: nickname: str = "", ) -> List[str]: """生成资料收集开场白,不持久化(由 Orchestrator 负责)""" - if not self.llm: - return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"] try: prompt = get_profile_greeting_prompt(missing_fields, nickname) hw = await get_history_with_window( diff --git a/api/app/agents/image_prompt/orchestrator.py b/api/app/agents/image_prompt/orchestrator.py index 89ecb57..332517c 100644 --- a/api/app/agents/image_prompt/orchestrator.py +++ b/api/app/agents/image_prompt/orchestrator.py @@ -9,8 +9,12 @@ from __future__ import annotations from typing import Any, Optional from app.agents.image_prompt.prompt_agent import PromptGenerationAgent +from app.core.config import settings +from app.core.logging import get_logger from app.features.memoir.memoir_images.settings import MemoirImageSettings +logger = get_logger(__name__) + class ImagePromptOrchestrator: """ @@ -76,5 +80,15 @@ def get_image_prompt_orchestrator() -> ImagePromptOrchestrator: """Celery / 后台任务入口:统一装配 LLM 与 MemoirImageSettings。""" from app.core.llm_gateway import LlmGateway, LlmUseCase - llm = LlmGateway().langchain_llm_for(LlmUseCase("image_prompt")) - return ImagePromptOrchestrator(llm=llm, settings=MemoirImageSettings.from_env()) + image_settings = MemoirImageSettings.from_env() + try: + llm = LlmGateway().langchain_llm_for(LlmUseCase("image_prompt")) + except Exception as e: + if settings.image_prompt_fallback_disabled: + raise + logger.warning( + "ImagePromptOrchestrator LLM 初始化失败,使用确定性 fallback: {}", + e, + ) + llm = None + return ImagePromptOrchestrator(llm=llm, settings=image_settings) diff --git a/api/app/agents/memoir/batch_phase1_prep.py b/api/app/agents/memoir/batch_phase1_prep.py index f76ac87..829ceca 100644 --- a/api/app/agents/memoir/batch_phase1_prep.py +++ b/api/app/agents/memoir/batch_phase1_prep.py @@ -10,7 +10,6 @@ from typing import Any, Callable, Dict, List from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt from app.agents.memoir.schemas import 
BatchPhase1LLMOutput -from app.agents.stage_constants import STAGE_SLOT_KEYS from app.agents.state_schema import MemoirStateSchema from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call @@ -19,11 +18,6 @@ from app.features.conversation.models import Segment logger = get_logger(__name__) -STAGE_ALLOWED_SLOTS: Dict[str, frozenset[str]] = { - k: frozenset(v) for k, v in STAGE_SLOT_KEYS.items() -} - - def _slots_snapshot(state: MemoirStateSchema) -> dict: snap: dict = {} for stage, buckets in (state.slots or {}).items(): diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index 2ecbe20..56d7091 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -8,12 +8,9 @@ from __future__ import annotations import time from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set -from app.agents.memoir.batch_phase1_prep import ( - STAGE_ALLOWED_SLOTS, - run_batch_phase1_prep_chunked, -) +from app.agents.memoir.batch_phase1_prep import run_batch_phase1_prep_chunked from app.agents.memoir.classification_agent import ( ClassificationAgent, _looks_like_fragment_only, @@ -22,7 +19,11 @@ from app.agents.memoir.classification_agent import ( _detect_stage as detect_stage_from_keywords, ) from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult -from app.agents.stage_constants import normalize_chapter_category, normalize_chat_stage +from app.agents.stage_constants import ( + filter_stage_slots, + normalize_chapter_category, + normalize_chat_stage, +) from app.agents.state_schema import MemoirStateSchema from app.core.agent_logging import agent_span, agent_summary_enabled, log_agent_detail from app.core.config import settings @@ -92,7 +93,7 @@ class MemoirOrchestrator: ) if use_batch: try: - result = self._prepare_batches_via_batch_llm( + 
prepared_batch = self._prepare_batches_via_batch_llm( segments=segments, state=state, classify_extract_llm=classify_extract_llm, @@ -104,7 +105,7 @@ class MemoirOrchestrator: "msg=Phase1 批处理 LLM 路径已使用", len(segments), ) - return result + return prepared_batch except Exception as e: logger.warning( "event=phase1_batch_path_fallback segment_count={} exc={} " @@ -132,8 +133,12 @@ class MemoirOrchestrator: stage_slots=stage_slots_raw, llm=classify_extract_llm, ) - detected_stage = result.detected_stage - for slot_name, snippet in result.slots.items(): + fb = state.current_stage or "childhood" + detected_stage = normalize_chat_stage(result.detected_stage, fb) + result_slots = filter_stage_slots(detected_stage, result.slots, fb) + if not result_slots: + detected_stage = normalize_chat_stage(fb, fb) + for slot_name, snippet in result_slots.items(): state = update_slot(detected_stage, slot_name, snippet, [segment.id]) with agent_span( @@ -148,7 +153,7 @@ class MemoirOrchestrator: segment_id=segment.id, ) chapter_category = classify_result.category - if (not result.slots) and classify_result.llm_said_none: + if (not result_slots) and classify_result.llm_said_none: segment_skip_story_ids.add(str(segment.id)) segment_chapter_category[str(segment.id)] = chapter_category @@ -166,7 +171,7 @@ class MemoirOrchestrator: logger, "MemoirOrchestrator.segment_done segment_id={} slots={}", segment.id, - list((result.slots or {}).keys()), + list(result_slots.keys()), ) category_to_segments.setdefault(chapter_category, []).append(segment) @@ -211,8 +216,7 @@ class MemoirOrchestrator: else: detected_stage = normalize_chat_stage(row.detected_stage, fb) - allowed = STAGE_ALLOWED_SLOTS.get(detected_stage, frozenset()) - result_slots = {k: v for k, v in result_slots.items() if k in allowed} + result_slots = filter_stage_slots(detected_stage, result_slots, fb) if not result_slots: detected_stage = normalize_chat_stage(fb, fb) @@ -269,72 +273,3 @@ class MemoirOrchestrator: 
segment_skip_story_ids=segment_skip_story_ids, segment_chapter_category=segment_chapter_category, ) - - def run( - self, - *, - segments: List[Segment], - llm: Any, - user_profile: str = "", - user_birth_year: Any = None, - get_or_create_state: Callable[[], MemoirStateSchema], - update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema], - acquire_lock: Callable[[str], bool], - release_lock: Callable[[str], None], - process_category: Callable[ - [ - str, - List[Segment], - MemoirStateSchema, - str, - Any, - Any, - ], - Tuple[Any, bool], - ], - raise_retry: Callable[[], None], - llm_fast: Any | None = None, - ) -> Tuple[Set[str], int]: - """ - 执行回忆录流水线。 - process_category(category, segments, state, user_profile, user_birth_year, llm) - 返回 (chapter, has_images_to_generate)。 - 返回 (chapters_to_enqueue, processed_count)。 - raise_retry 用于锁竞争时抛出 Celery retry。 - """ - prepared = self.prepare_batches( - segments=segments, - llm=llm, - llm_fast=llm_fast, - get_or_create_state=get_or_create_state, - update_slot=update_slot, - on_phase1_chunk=None, - ) - state = prepared.state - chapters_to_enqueue: Set[str] = set() - category_to_segments = prepared.category_to_segments - - # 按 category 调用 process_category:叙事生成、持久化、封面入队标记 - for chapter_category, category_segments in category_to_segments.items(): - if not acquire_lock(chapter_category): - logger.warning( - "章节锁竞争: category={}, 延迟重试", - chapter_category, - ) - raise_retry() - - try: - chapter, has_images = process_category( - chapter_category, - category_segments, - state, - user_profile, - user_birth_year, - llm, - ) - if chapter and has_images: - chapters_to_enqueue.add(chapter.id) - finally: - release_lock(chapter_category) - - return chapters_to_enqueue, len(segments) diff --git a/api/app/agents/memoir/prompts.py b/api/app/agents/memoir/prompts.py index 4a5476c..e636d21 100644 --- a/api/app/agents/memoir/prompts.py +++ b/api/app/agents/memoir/prompts.py @@ -11,7 +11,6 @@ from app.agents.chat.background_voice import 
get_background_voice_narrative_bloc from app.agents.chat.occupation_context import get_occupation_narrative_hint from app.agents.stage_constants import STAGE_ERA_HINTS, STAGE_SLOT_KEYS from app.agents.style_profiles import MemoirStyleProfile -from app.features.memory.evidence_format import format_evidence_chunks_for_prompt def _memoir_fidelity_core_rules() -> str: diff --git a/api/app/agents/memoir/story_route_agent.py b/api/app/agents/memoir/story_route_agent.py index b18dab1..6e1a00d 100644 --- a/api/app/agents/memoir/story_route_agent.py +++ b/api/app/agents/memoir/story_route_agent.py @@ -31,6 +31,9 @@ PLAN_BATCH_MAX_SEGMENTS = 48 # 童年 / 求学 / 家庭:模型与后处理均倾向「少拆分、优先续写」 APPEND_FIRST_CHAPTER_CATEGORIES = frozenset({"childhood", "education", "family"}) +# These route outcomes are conservative fail-safes, not semantic append matches. +FALLBACK_NEW_STORY_REASONS = frozenset({"no_llm", "parse_error", "invalid_target"}) + def default_append_target_story_id( candidate_stories: list[Story], @@ -220,13 +223,6 @@ class StoryRouteAgent: story_meta: dict[str, dict[str, int]] | None = None, ) -> StoryRouteDecision: if not llm: - fb = default_append_target_story_id(candidate_stories, story_meta, settings) - if fb and fb in valid_story_ids: - return StoryRouteDecision( - decision="append_story", - target_story_id=fb, - reason="no_llm_default_append", - ) return StoryRouteDecision( decision="new_story", new_story_title=None, @@ -241,13 +237,6 @@ class StoryRouteAgent: ) def _decide_fallback() -> StoryRouteDecision: - fb = default_append_target_story_id(candidate_stories, story_meta, settings) - if fb and fb in valid_story_ids: - return StoryRouteDecision( - decision="append_story", - target_story_id=fb, - reason="parse_error_default_append", - ) return StoryRouteDecision( decision="new_story", new_story_title=None, @@ -266,22 +255,8 @@ class StoryRouteAgent: if decision.decision == "append_story": tid = decision.target_story_id if not tid or tid not in valid_story_ids: - fb = 
default_append_target_story_id( - candidate_stories, story_meta, settings - ) - if fb and fb in valid_story_ids: - logger.info( - "StoryRoute append 无效 target_story_id={},回退默认 append {}", - tid, - fb, - ) - return StoryRouteDecision( - decision="append_story", - target_story_id=fb, - reason="invalid_target_default_append", - ) logger.warning( - "StoryRoute append 无效 target_story_id={},且无可用默认目标,回退 new_story", + "StoryRoute append 无效 target_story_id={},回退 new_story", tid, ) return StoryRouteDecision( diff --git a/api/app/agents/stage_constants.py b/api/app/agents/stage_constants.py index 0ff3001..4281831 100644 --- a/api/app/agents/stage_constants.py +++ b/api/app/agents/stage_constants.py @@ -68,6 +68,35 @@ STAGE_SLOT_KEYS: dict[str, tuple[str, ...]] = { "belief": ("value", "regret", "pride", "lesson"), } +STAGE_ALLOWED_SLOTS: dict[str, frozenset[str]] = { + k: frozenset(v) for k, v in STAGE_SLOT_KEYS.items() +} + + +def allowed_slot_names_for_stage( + stage: str | None, + fallback: str = "childhood", +) -> frozenset[str]: + stage_norm = normalize_chat_stage(stage, fallback=fallback) + return STAGE_ALLOWED_SLOTS.get(stage_norm, frozenset()) + + +def is_valid_stage_slot( + stage: str | None, + slot_name: str, + fallback: str = "childhood", +) -> bool: + return slot_name in allowed_slot_names_for_stage(stage, fallback=fallback) + + +def filter_stage_slots( + stage: str | None, + slots: dict[str, str], + fallback: str = "childhood", +) -> dict[str, str]: + allowed = allowed_slot_names_for_stage(stage, fallback=fallback) + return {k: v for k, v in (slots or {}).items() if k in allowed} + # 人生阶段 / 章节类目的年龄参照(仅用于 prompt 时间提示;非业务校验) STAGE_ERA_HINTS: dict[str, tuple[int, int]] = { "childhood": (0, 12), diff --git a/api/app/agents/style_profiles.py b/api/app/agents/style_profiles.py index 018a511..e15cb89 100644 --- a/api/app/agents/style_profiles.py +++ b/api/app/agents/style_profiles.py @@ -22,7 +22,6 @@ from __future__ import annotations from dataclasses import dataclass, 
field from typing import List, Tuple - # ============================================================================= # 共享:Memoir 评测维度单一事实源 # ============================================================================= diff --git a/api/app/core/config.py b/api/app/core/config.py index 8adceda..24a0347 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -355,6 +355,13 @@ class Settings(BaseSettings): memoir_recompose_retry_on_lock_contention: bool = True # Phase2 立即派发使用固定 task_id,减少同类目重复入队(超时任务仍用独立 id) memoir_phase2_singleflight_immediate: bool = True + # True:Phase2 路由低置信(no_llm/parse_error/invalid_target)时不写 Story, + # 把 segment 标记为 narrative_deferred_until 之后再重试。 + memoir_route_defer_enabled: bool = True + # 低置信延迟时长(秒):到期前不消费这些 segment,避免后台空转 + memoir_route_defer_seconds: float = Field(default=120.0, ge=1.0, le=3600.0) + # 同一类目最多自动延迟次数;达到上限后 segment 仅靠新素材到达激活,不再自动重试 + memoir_route_defer_max_attempts: int = Field(default=3, ge=1, le=20) # True:Phase2 首稿后异步运行质量增强(fidelity recheck、标题润色、LLM 归一) memoir_quality_pass_enabled: bool = True memoir_quality_pass_delay_seconds: int = Field(default=5, ge=0, le=300) diff --git a/api/app/core/llm_gateway.py b/api/app/core/llm_gateway.py index 2b10b88..3e7c16d 100644 --- a/api/app/core/llm_gateway.py +++ b/api/app/core/llm_gateway.py @@ -49,13 +49,18 @@ class LlmGateway: max_tokens: int | None = None, ) -> str: provider = self.provider_for(use_case) + resolved_temperature = ( + temperature + if temperature is not None + else ( + use_case.temperature + if use_case and use_case.temperature is not None + else 0.7 + ) + ) return await provider.complete( messages, - temperature=( - temperature - if temperature is not None - else (use_case.temperature if use_case else 0.7) - ), + temperature=resolved_temperature, model=model if model is not None else (use_case.model if use_case else None), max_tokens=( max_tokens diff --git a/api/app/core/log_events.py b/api/app/core/log_events.py index 788d5a4..cb16fdd 100644 --- 
a/api/app/core/log_events.py +++ b/api/app/core/log_events.py @@ -58,6 +58,7 @@ def correlation_bind_kwargs( # bind=True 任务的 positional 与字段名映射(kwargs 优先,缺位再填) _TASK_POSITIONAL_FIELDS: dict[str, tuple[str, ...]] = { + "app.tasks.memory_enrichment_tasks.embed_memory_source": ("user_id", "source_id"), "app.tasks.memory_enrichment_tasks.enrich_memory_source": ("user_id", "source_id"), "app.tasks.memory_compaction_tasks.memory_compaction_run": ("user_id",), "app.tasks.chapter_compose_tasks.recompose_chapter": ("chapter_id",), diff --git a/api/app/features/auth/avatar_presets/01.png b/api/app/features/auth/avatar_presets/01.png new file mode 100644 index 0000000..08742d1 Binary files /dev/null and b/api/app/features/auth/avatar_presets/01.png differ diff --git a/api/app/features/auth/avatar_presets/02.png b/api/app/features/auth/avatar_presets/02.png new file mode 100644 index 0000000..4e845b8 Binary files /dev/null and b/api/app/features/auth/avatar_presets/02.png differ diff --git a/api/app/features/auth/avatar_presets/03.png b/api/app/features/auth/avatar_presets/03.png new file mode 100644 index 0000000..192122f Binary files /dev/null and b/api/app/features/auth/avatar_presets/03.png differ diff --git a/api/app/features/auth/avatar_presets/04.png b/api/app/features/auth/avatar_presets/04.png new file mode 100644 index 0000000..8e4d101 Binary files /dev/null and b/api/app/features/auth/avatar_presets/04.png differ diff --git a/api/app/features/auth/avatar_presets/05.png b/api/app/features/auth/avatar_presets/05.png new file mode 100644 index 0000000..5920a73 Binary files /dev/null and b/api/app/features/auth/avatar_presets/05.png differ diff --git a/api/app/features/auth/avatar_presets/06.png b/api/app/features/auth/avatar_presets/06.png new file mode 100644 index 0000000..72e48ad Binary files /dev/null and b/api/app/features/auth/avatar_presets/06.png differ diff --git a/api/app/features/auth/avatar_presets/07.png b/api/app/features/auth/avatar_presets/07.png new 
file mode 100644 index 0000000..26d2f84 Binary files /dev/null and b/api/app/features/auth/avatar_presets/07.png differ diff --git a/api/app/features/auth/avatar_presets/08.png b/api/app/features/auth/avatar_presets/08.png new file mode 100644 index 0000000..23e5f4a Binary files /dev/null and b/api/app/features/auth/avatar_presets/08.png differ diff --git a/api/app/features/auth/preset_avatars.py b/api/app/features/auth/preset_avatars.py new file mode 100644 index 0000000..e4a5c14 --- /dev/null +++ b/api/app/features/auth/preset_avatars.py @@ -0,0 +1,68 @@ +"""服务端托管的预设头像(白名单文件名)。""" + +from __future__ import annotations + +from pathlib import Path + +_PRESETS_DIR = Path(__file__).resolve().parent / "avatar_presets" + +# 与仓库内静态文件一致:01.png … 08.png +ALLOWED_PRESET_FILENAMES: frozenset[str] = frozenset(f"{i:02d}.png" for i in range(1, 9)) + +PRESET_IDS: tuple[str, ...] = tuple(f"{i:02d}" for i in range(1, 9)) + + +def preset_filename_for_id(preset_id: str) -> str | None: + """preset_id 形如 \"01\",返回 \"01.png\";非法则 None。""" + stripped = preset_id.strip() + name = f"{stripped}.png" + if name in ALLOWED_PRESET_FILENAMES: + return name + return None + + +def avatar_url_for_preset_filename(filename: str) -> str: + return f"/api/auth/avatar-presets/{filename}" + + +def list_preset_items() -> list[tuple[str, str]]: + """(preset_id, avatar_url) 列表,供 GET /avatar-presets。""" + return [ + (pid, avatar_url_for_preset_filename(f"{pid}.png")) for pid in PRESET_IDS + ] + + +def preset_file_path(filename: str) -> Path | None: + if filename not in ALLOWED_PRESET_FILENAMES: + return None + path = (_PRESETS_DIR / filename).resolve() + try: + path.relative_to(_PRESETS_DIR.resolve()) + except ValueError: + return None + return path + + +def _avatar_upload_stem_allowed(stem: str) -> bool: + """允许 UUID、数字 ID、含连字符/下划线的安全文件名主干。""" + if not stem or len(stem) > 128: + return False + return all(c.isalnum() or c in "-_" for c in stem) + + +def safe_avatar_upload_path(filename: str, avatar_dir: 
Path) -> Path | None: + """用户上传头像文件名形如 {user_id}.jpg,防路径穿越。""" + if "/" in filename or "\\" in filename or ".." in filename: + return None + if not filename.endswith(".jpg"): + return None + stem = filename[:-4] + if not _avatar_upload_stem_allowed(stem): + return None + base = avatar_dir.resolve() + path = (avatar_dir / filename).resolve() + try: + path.relative_to(base) + except ValueError: + return None + return path diff --git a/api/app/features/auth/router.py b/api/app/features/auth/router.py index 68e230b..c6cd9fd 100644 --- a/api/app/features/auth/router.py +++ b/api/app/features/auth/router.py @@ -9,7 +9,15 @@ from app.core.config import settings from app.core.dependencies import get_current_user from app.core.logging import get_logger from app.features.auth.deps import get_auth_service +from app.features.auth.preset_avatars import ( + avatar_url_for_preset_filename, + list_preset_items, + preset_filename_for_id, + preset_file_path, + safe_avatar_upload_path, +) from app.features.auth.schemas import ( + AvatarPresetItem, ChangePasswordRequest, ChangePhoneRequest, LoginRequest, @@ -18,6 +26,7 @@ from app.features.auth.schemas import ( RegisterRequest, ResetPasswordRequest, SendSmsRequest, + SetAvatarPresetRequest, SmsLoginRequest, SmsRegisterRequest, TokenResponse, @@ -329,14 +338,72 @@ async def upload_avatar( ) from e +@router.get( + "/avatar-presets", + response_model=list[AvatarPresetItem], + summary="预设头像列表", +) +async def list_avatar_presets(): + return [ + AvatarPresetItem(id=item_id, url=item_url) + for item_id, item_url in list_preset_items() + ] + + +@router.put( + "/me/avatar/preset", + response_model=UserResponse, + summary="使用预设头像", + responses={400: {"description": "无效的预设编号"}}, +) +async def set_avatar_preset( + request: SetAvatarPresetRequest, + current_user: User = Depends(get_current_user), + service: AuthService = Depends(get_auth_service), +): + filename = preset_filename_for_id(request.preset_id) + if filename is None: + raise 
HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的预设头像编号", + ) + path = preset_file_path(filename) + if path is None or not path.exists(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="预设头像不可用", + ) + avatar_url = avatar_url_for_preset_filename(filename) + try: + user = await service.update_avatar_url(current_user.id, avatar_url) + except AuthError as e: + raise _map_auth_error(e) + return _user_response(user) + + +@router.get( + "/avatar-presets/{filename}", + summary="获取预设头像图片", + responses={404: {"description": "预设不存在"}}, +) +async def get_avatar_preset(filename: str): + path = preset_file_path(filename) + if path is None or not path.exists(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="预设头像不存在", + ) + return FileResponse(path, media_type="image/png") + + @router.get( "/avatars/{filename}", summary="获取头像图片", responses={404: {"description": "头像不存在"}}, ) async def get_avatar(filename: str): - file_path = AVATAR_DIR / filename - if not file_path.exists(): + file_path = safe_avatar_upload_path(filename, AVATAR_DIR) + if file_path is None or not file_path.exists(): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="头像不存在", diff --git a/api/app/features/auth/schemas.py b/api/app/features/auth/schemas.py index b3c5dba..15bac34 100644 --- a/api/app/features/auth/schemas.py +++ b/api/app/features/auth/schemas.py @@ -104,3 +104,18 @@ class UpdateNicknameRequest(BaseModel): class AvatarUploadResponse(BaseModel): avatar_url: str + + +class SetAvatarPresetRequest(BaseModel): + preset_id: str = Field( + ..., + min_length=2, + max_length=2, + pattern=r"^\d{2}$", + description="预设编号,如 01–08", + ) + + +class AvatarPresetItem(BaseModel): + id: str + url: str diff --git a/api/app/features/conversation/models.py b/api/app/features/conversation/models.py index 36f3bfe..56b1306 100644 --- a/api/app/features/conversation/models.py +++ b/api/app/features/conversation/models.py @@ -58,6 
+58,13 @@ class Segment(Base): narrated = Column(Boolean, default=False, server_default="false") # Phase 1 判定无需进故事管线(无 slots 且 LLM 判 none) skip_narrative = Column(Boolean, default=False, server_default="false") + # Phase 2 路由低置信延迟:到期前不消费;新同类目素材到达可清空。 + narrative_deferred_until = Column(DateTime(timezone=True), nullable=True) + narrative_defer_count = Column( + Integer, nullable=False, default=0, server_default="0" + ) + narrative_defer_reason = Column(String, nullable=True) + narrative_last_attempt_at = Column(DateTime(timezone=True), nullable=True) agent_response = Column(Text, nullable=True) tts_audio_urls = Column(JSON, nullable=True) # 用户轮次 durable message id(与 lineage_json 同步;便于查询) diff --git a/api/app/features/conversation/service.py b/api/app/features/conversation/service.py index 6a9b3ad..5668fe3 100644 --- a/api/app/features/conversation/service.py +++ b/api/app/features/conversation/service.py @@ -96,6 +96,9 @@ def _build_messages_from_history( tts = msg.get("ttsAudioUrls") if isinstance(tts, list) and tts: item["ttsAudioUrls"] = [x for x in tts if isinstance(x, str)] + dm = msg.get("durableMessageId") + if isinstance(dm, str) and dm: + item["durableMessageId"] = dm messages.append(item) return messages diff --git a/api/app/features/conversation/session_history.py b/api/app/features/conversation/session_history.py index 8d7926f..0573151 100644 --- a/api/app/features/conversation/session_history.py +++ b/api/app/features/conversation/session_history.py @@ -18,6 +18,7 @@ def conversation_messages_to_redis_history( "content": row.content, "messageType": row.message_type, "timestamp": row.created_at.isoformat() if row.created_at else None, + "durableMessageId": row.id, } if row.voice_session_id: item["voiceSessionId"] = row.voice_session_id diff --git a/api/app/features/conversation/tts_delivery.py b/api/app/features/conversation/tts_delivery.py index 762b5af..6f3c717 100644 --- a/api/app/features/conversation/tts_delivery.py +++ 
b/api/app/features/conversation/tts_delivery.py @@ -9,9 +9,15 @@ from __future__ import annotations -from app.core.cos_url_keys import presign_tts_urls_for_playback +from app.core.cos_url_keys import ( + TTS_PRESIGNED_EXPIRES_SEC, + extract_cos_object_key_if_owned, +) +from app.core.logging import get_logger from app.ports.storage import ObjectStorage +logger = get_logger(__name__) + def apply_presigned_tts_urls_to_messages( messages: list[dict], @@ -24,5 +30,26 @@ def apply_presigned_tts_urls_to_messages( tts = m.get("ttsAudioUrls") if not isinstance(tts, list) or not tts: continue - str_urls = [x for x in tts if isinstance(x, str)] - m["ttsAudioUrls"] = presign_tts_urls_for_playback(str_urls, storage) + out: list[str] = [] + for x in tts: + if not isinstance(x, str): + out.append("") + continue + s = x.strip() + if not s: + out.append("") + continue + key = extract_cos_object_key_if_owned(s) + if key: + try: + out.append(storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)) + except Exception as exc: + logger.warning( + "presign tts url failed, keeping original url: key={} err={}", + key, + exc, + ) + out.append(s) + else: + out.append(s) + m["ttsAudioUrls"] = out diff --git a/api/app/features/conversation/ws/message_types.py b/api/app/features/conversation/ws/message_types.py index 3aafd99..8f5883d 100644 --- a/api/app/features/conversation/ws/message_types.py +++ b/api/app/features/conversation/ws/message_types.py @@ -17,6 +17,7 @@ class MessageType(str, Enum): AGENT_RESPONSE = "agent_response" TTS_AUDIO = "tts_audio" TTS_CANCEL = "tts_cancel" + TTS_REQUEST = "tts_request" PING = "ping" PONG = "pong" END_CONVERSATION = "end_conversation" diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index aa028db..ee96511 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -18,9 +18,13 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio 
import AsyncSession from app.agents.chat import ChatOrchestrator +from app.agents.chat.reply_limits import segments_from_llm_response from app.core.agent_logging import agent_summary_enabled from app.core.config import settings -from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC +from app.core.cos_url_keys import ( + TTS_PRESIGNED_EXPIRES_SEC, + extract_cos_object_key_if_owned, +) from app.core.db import AsyncSessionLocal from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider from app.features.conversation.chat_turn import ( @@ -33,7 +37,7 @@ from app.features.conversation.history_store import ( ConversationHistoryStore, ) from app.features.conversation.lineage_schemas import DialogueLineage -from app.features.conversation.models import Conversation, Segment +from app.features.conversation.models import Conversation, ConversationMessage, Segment from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType from app.features.conversation.ws.profile_collector import ( @@ -84,6 +88,7 @@ async def _send_tts_audio( chunk_total: int, assistant_message_id: str | None, tts_epoch_start: int, + manual: bool = False, ) -> str | None: """Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. 
Returns public URL or None.""" if not settings.enable_tts: @@ -116,6 +121,8 @@ async def _send_tts_audio( } if assistant_message_id: payload_data["assistant_message_id"] = assistant_message_id + if manual: + payload_data["manual"] = True await manager.send_message( conversation_id, { @@ -138,6 +145,109 @@ async def _send_tts_audio( return None +async def handle_tts_request_on_demand( + *, + conversation_id: str, + user_id: str, + assistant_message_id: str, + segment_index: int, + segment_text: str | None, + db: AsyncSession, +) -> tuple[bool, str]: + """用户点喇叭:该段已有 TTS 则预签名下发;否则合成后落库并下发。不重复合成同一段。""" + if not settings.enable_tts: + return False, "未开启语音合成" + + conv = await db.get(Conversation, conversation_id) + if not conv or conv.user_id != user_id or conv.deleted_at is not None: + return False, "对话不存在或无权访问" + + msg = await db.get(ConversationMessage, assistant_message_id) + if not msg or msg.conversation_id != conversation_id or msg.role != "ai": + return False, "消息不存在" + + # 与客户端 splitMessageParts / segments_from_llm_response 对齐(含无 [SPLIT] 时的段落拆段) + parts = segments_from_llm_response(msg.content or "", max_segments=3) + if segment_index < 0 or segment_index >= len(parts): + return False, "分段序号无效" + + canon = (parts[segment_index] or "").strip() + if not canon: + return False, "该段无朗读文本" + if segment_text and segment_text.strip() and segment_text.strip() != canon: + logger.debug( + "按需 TTS: 客户端传入 segment_text 与规范化后 canon 不完全一致,已按 segment_index 朗读 canon " + "(client_len={} canon_len={})", + len(segment_text.strip()), + len(canon), + ) + + urls: List[str] = [] + for x in msg.tts_audio_urls or []: + if isinstance(x, str) and x.strip(): + urls.append(x) + else: + urls.append("") + while len(urls) < len(parts): + urls.append("") + + existing = urls[segment_index].strip() if segment_index < len(urls) else "" + chunk_total = len(parts) + + if existing: + storage = get_object_storage() + key = extract_cos_object_key_if_owned(existing) + try: + playback_url = ( + 
storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC) + if key + else existing + ) + except Exception as exc: + logger.warning("按需 TTS 预签名失败: {}", exc) + playback_url = existing + await manager.send_message( + conversation_id, + { + "type": MessageType.TTS_AUDIO, + "conversation_id": conversation_id, + "data": { + "audio_url": playback_url, + "format": settings.tts_codec, + "index": segment_index, + "total": chunk_total, + "assistant_message_id": assistant_message_id, + "manual": True, + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) + return True, "" + + tts_epoch_start = _tts_epoch_value(conversation_id) + url_stored = await _send_tts_audio( + conversation_id, + canon, + chunk_index=segment_index, + chunk_total=chunk_total, + assistant_message_id=assistant_message_id, + tts_epoch_start=tts_epoch_start, + manual=True, + ) + if not url_stored: + return False, "语音合成失败" + + while len(urls) <= segment_index: + urls.append("") + urls[segment_index] = url_stored + msg.tts_audio_urls = urls + await db.commit() + + store = ConversationHistoryStore(db) + await store._sync_redis_best_effort(conversation_id) + return True, "" + + # ── Agent 实例(从 ConnectionManager 移出) ───────────────────── chat_orchestrator = ChatOrchestrator() chat_turn_service = ChatTurnService(chat_orchestrator) @@ -153,6 +263,8 @@ class SegmentStreamState: """会话内分段处理状态(用于并行 ASR + 有序聚合)""" lock: asyncio.Lock = field(default_factory=asyncio.Lock) + #: 本条语音会话最近一次分段上行携带的本轮朗读开关(客户端每段一致即可) + tts_this_turn: bool = False pending_indices: Set[int] = field(default_factory=set) processed_indices: Set[int] = field(default_factory=set) buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict) @@ -163,6 +275,43 @@ class SegmentStreamState: _segment_states: Dict[Tuple[str, str], SegmentStreamState] = {} +_user_response_tasks: Dict[str, Set[asyncio.Task]] = {} +_user_response_locks: Dict[str, asyncio.Lock] = {} + + +def _get_user_response_lock(conversation_id: str) -> 
asyncio.Lock: + lock = _user_response_locks.get(conversation_id) + if lock is None: + lock = asyncio.Lock() + _user_response_locks[conversation_id] = lock + return lock + + +def register_user_response_task(conversation_id: str, task: asyncio.Task) -> None: + tasks = _user_response_tasks.setdefault(conversation_id, set()) + tasks.add(task) + + def _cleanup(done_task: asyncio.Task) -> None: + tasks.discard(done_task) + if not tasks: + _user_response_tasks.pop(conversation_id, None) + _user_response_locks.pop(conversation_id, None) + if done_task.cancelled(): + logger.warning( + "用户回复后台任务被取消 conversation_id={}", + conversation_id, + ) + return + exc = done_task.exception() + if exc: + logger.error( + "用户回复后台任务异常 conversation_id={}: {}", + conversation_id, + exc, + exc_info=True, + ) + + task.add_done_callback(_cleanup) def get_or_create_segment_state( @@ -432,9 +581,13 @@ async def process_audio_segment( audio_base64: str, audio_duration: int, is_last: bool, + *, + tts_this_turn: bool = False, ) -> None: """分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。""" state = get_or_create_segment_state(conversation_id, voice_session_id) + async with state.lock: + state.tts_this_turn = bool(tts_this_turn) logger.info( "process_audio_segment 开始: conversation_id={} voice_session_id={} " "segment_index={} is_last={} duration_s={} audio_b64_len={}", @@ -588,6 +741,7 @@ async def process_audio_segment( ) ready_segments: List[Tuple[int, str, Segment]] = [] + tts_flag_this_voice_session = False async with state.lock: state.processed_indices.add(segment_index) state.buffered_transcripts[segment_index] = ( @@ -602,6 +756,8 @@ async def process_audio_segment( state.consumed_index = next_index next_index += 1 + tts_flag_this_voice_session = bool(state.tts_this_turn) + for _, ordered_text, ordered_segment in ready_segments: await process_user_message( conversation_id=conversation_id, @@ -612,6 +768,7 @@ async def process_audio_segment( user=user, user_message_timestamp=ordered_segment.created_at or 
user_message_timestamp, + tts_this_turn=tts_flag_this_voice_session, ) except Exception as e: @@ -638,6 +795,48 @@ async def process_audio_segment( # ── 用户消息处理 ──────────────────────────────────────────────── +async def process_persisted_user_segment_response( + *, + conversation_id: str, + user_id: str, + segment_id: str, + tts_this_turn: bool = False, +) -> None: + """后台继续生成已落库用户段落的助手回复;即使 WS 页面退出也要完成落库。""" + lock = _get_user_response_lock(conversation_id) + async with lock: + async with AsyncSessionLocal() as db: + conversation = await db.get(Conversation, conversation_id) + user = await db.get(User, user_id) + segment = await db.get(Segment, segment_id) + if ( + not conversation + or conversation.deleted_at is not None + or conversation.user_id != user_id + or not user + or not segment + or segment.conversation_id != conversation_id + ): + logger.warning( + "跳过用户回复后台任务: conversation_id={} segment_id={} user_id={}", + conversation_id, + segment_id, + user_id, + ) + return + await process_user_message( + conversation_id=conversation_id, + user_message=segment.user_input_text or "", + conversation=conversation, + segment=segment, + db=db, + user=user, + user_message_timestamp=segment.created_at + or conversation.last_message_at, + tts_this_turn=tts_this_turn, + ) + + async def process_user_message( conversation_id: str, user_message: str, @@ -648,6 +847,7 @@ async def process_user_message( user_message_timestamp: Optional[datetime] = None, *, force_skip_tts: bool = False, + tts_this_turn: Optional[bool] = None, ) -> None: """处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。""" store = ConversationHistoryStore(db) @@ -682,20 +882,23 @@ async def process_user_message( get_filled_profile_fields_fn=get_filled_profile_fields, ), ) + responses = turn.messages + skip_tts = bool(turn.skip_tts) + want_voice = bool(tts_this_turn) if tts_this_turn is not None else False + want_tts = want_voice and settings.enable_tts and not skip_tts if 
agent_summary_enabled(): logger.info( "pipeline.process_user_message duration_ms={:.2f} " "conversation_id={} segment_id={} user_msg_len={} " - "response_segments={} skip_tts={}", + "response_segments={} skip_tts={} want_tts={}", (time.perf_counter() - t_pipeline) * 1000, conversation_id, segment.id, len(user_message or ""), len(turn.messages), turn.skip_tts, + want_tts, ) - responses = turn.messages - skip_tts = bool(turn.skip_tts) segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses) _mark_conversation_active(conversation) @@ -750,6 +953,21 @@ async def process_user_message( tts_epoch_start = _tts_epoch_value(conversation_id) n = len(responses) for i, response_text in enumerate(responses): + url_for_segment: Optional[str] = None + if want_tts: + if _tts_epoch_value(conversation_id) != tts_epoch_start: + break + url_for_segment = await _send_tts_audio( + conversation_id, + response_text, + chunk_index=i, + chunk_total=n, + assistant_message_id=ai_msg_id, + tts_epoch_start=tts_epoch_start, + ) + if url_for_segment: + tts_urls.append(url_for_segment) + await manager.send_message( conversation_id, { @@ -764,20 +982,7 @@ async def process_user_message( "timestamp": datetime.now(timezone.utc).isoformat(), }, ) - url = None - if not skip_tts: - if _tts_epoch_value(conversation_id) != tts_epoch_start: - break - url = await _send_tts_audio( - conversation_id, - response_text, - chunk_index=i, - chunk_total=n, - assistant_message_id=ai_msg_id, - tts_epoch_start=tts_epoch_start, - ) - if url: - tts_urls.append(url) + if _tts_epoch_value(conversation_id) != tts_epoch_start: break if i < n - 1: diff --git a/api/app/features/conversation/ws/protocol.md b/api/app/features/conversation/ws/protocol.md index 85676a7..c198c8a 100644 --- a/api/app/features/conversation/ws/protocol.md +++ b/api/app/features/conversation/ws/protocol.md @@ -1,25 +1,35 @@ # WebSocket 消息协议 ## 连接 -- URL: /ws/conversation/{conversation_id}?token={jwt_access_token} -- 鉴权: query 参数 token,JWT 
access_token + +- URL: `/ws/conversation/{conversation_id}?token={jwt_access_token}` +- 鉴权: query 参数 `token`,JWT `access_token` ## 消息类型 (client → server) -- TEXT: 文本消息 -- AUDIO_SEGMENT: 语音分段 -- AUDIO_MESSAGE: 完整语音消息 -- TRANSCRIBE_ONLY: 仅转写不回复 -- END_CONVERSATION: 结束对话 + +- `TEXT`:文本消息。`data.text` 必填。可选 `data.tts_this_turn`(布尔):为 `true` 且服务端 `ENABLE_TTS` 开启且本轮未被标记为 `skip_tts` 时,对该轮助手回复分段合成 TTS;默认为 `false`/缺省即不合成。**当开启本轮 TTS 时,每个助手分段服务端先推送 `tts_audio` 再推送该段 `agent_response`**,便于客户端先收音频再展示同段文字。 +- `AUDIO_SEGMENT`:语音分段。`data` 含 `audio_base64`、`segment_index`、`voice_session_id` / `client_segment_id`、`is_last`、`duration`。可选同上 `tts_this_turn`。 +- `AUDIO_MESSAGE`:整段音频(单次 ASR + 对话)。同上可选 `tts_this_turn`。 +- `TRANSCRIBE_ONLY`:仅转写不回复 +- `TTS_CANCEL`:取消当前轮未完成的分段合成与下发 +- `TTS_REQUEST`:用户点击某一助手气泡「朗读」且该段尚无 TTS 时下发。`data` 含 `assistant_message_id`(落库 `conversation_messages.id`)、`segment_index`(与该条助手正文按 `[SPLIT]` 分段后从 0 开始的下标)、可选 `segment_text`(须与该分段正文一致,用于校验)。服务端若该段已有 URL 则只做预签名后推送 `tts_audio`(`data.manual=true`),**不重复合成**。 +- `END_CONVERSATION`:结束对话 +- `PING` / `PONG`:心跳(客户端也可用 JSON `{"type":"ping"}`) ## 消息类型 (server → client) -- TRANSCRIPT: ASR 转写结果 -- AGENT_RESPONSE: AI 回复文本 -- TTS_AUDIO: 语音合成音频 (base64) -- MEMOIR_UPDATE: 回忆录更新通知 -- ERROR: 错误信息 + +- `TRANSCRIPT`: ASR 转写结果 +- `AGENT_RESPONSE`: AI 回复文本分段 +- `TTS_AUDIO`: 语音合成结果(可与 `COS` 签名 URL、`base64` 并存)。按需朗读成功时 `data.manual` 可为 `true`,提示客户端应播放(即使用户未开「本轮 Speak」)。 +- `MEMOIR_UPDATE`: 回忆录更新通知 +- `ERROR`: 错误信息 ## 状态流转 -CONNECT → (TEXT|AUDIO_*) ↔ (TRANSCRIPT|AGENT_RESPONSE|TTS_AUDIO) → END_CONVERSATION + +`CONNECT → (TEXT|AUDIO_*) ↔ (TRANSCRIPT|AGENT_RESPONSE|[TTS_AUDIO]) → END_CONVERSATION` + +同一连接内消息顺序稳定;首轮朗读模式下每一助手分段为 `tts_audio` 先于对应 `agent_response`。 ## 重连 -客户端断连后可用相同 conversation_id 重连,历史消息从 Redis 恢复。 + +客户端断连后可用相同 `conversation_id` 重连,历史消息从 Redis / HTTP 缓存恢复。 diff --git a/api/app/features/conversation/ws/router.py b/api/app/features/conversation/ws/router.py index ef89f8c..ab6a071 100644 ---
a/api/app/features/conversation/ws/router.py +++ b/api/app/features/conversation/ws/router.py @@ -34,11 +34,13 @@ from app.features.conversation.ws.pipeline import ( chat_orchestrator, cleanup_segment_states, get_or_create_segment_state, + handle_tts_request_on_demand, memoir_ingest_scheduler, process_audio_segment, process_conversation_segments, - process_user_message, + process_persisted_user_segment_response, register_segment_task, + register_user_response_task, ) from app.features.conversation.ws.profile_collector import get_missing_profile_fields from app.features.conversation.ws.quota_guard import check_ws_quota @@ -381,7 +383,9 @@ async def websocket_endpoint( ) if msg_type == MessageType.TEXT: - text_message = message.get("data", {}).get("text", "") + data = message.get("data") or {} + text_message = data.get("text", "") + tts_this_turn = bool(data.get("tts_this_turn")) if text_message: can_send, quota_msg = await check_ws_quota( @@ -408,23 +412,21 @@ async def websocket_endpoint( user_id, text_message, ) - user_message_timestamp = conversation.last_message_at await memoir_ingest_scheduler.queue_segment( conversation.user_id, segment.id, text_char_count=len(text_message.strip()), ) - await process_user_message( - conversation_id=conversation_id, - user_message=text_message, - conversation=conversation, - segment=segment, - db=db, - user=user, - user_message_timestamp=segment.created_at - or user_message_timestamp, + task = asyncio.create_task( + process_persisted_user_segment_response( + conversation_id=conversation_id, + user_id=user_id, + segment_id=segment.id, + tts_this_turn=tts_this_turn, + ) ) + register_user_response_task(conversation_id, task) elif msg_type == MessageType.RECORDING_STARTED: data = message.get("data", {}) @@ -591,6 +593,7 @@ async def websocket_endpoint( audio_base64=audio_base64, audio_duration=audio_duration, is_last=is_last, + tts_this_turn=bool(data.get("tts_this_turn")), ) ) register_segment_task(conversation_id, 
voice_session_id, task) @@ -599,6 +602,7 @@ async def websocket_endpoint( data = message.get("data", {}) audio_base64 = data.get("audio_base64", "") audio_duration = data.get("duration", 0) + tts_this_turn = bool(data.get("tts_this_turn")) if audio_base64: can_send, quota_msg = await check_ws_quota( @@ -669,7 +673,6 @@ async def websocket_endpoint( audio_duration_seconds=ads if ads > 0 else None, ) ) - user_message_timestamp = conversation.last_message_at await memoir_ingest_scheduler.queue_segment( conversation.user_id, segment.id, @@ -677,16 +680,15 @@ async def websocket_endpoint( ) if asr_text and not asr_text.startswith("转写失败"): - await process_user_message( - conversation_id=conversation_id, - user_message=asr_text, - conversation=conversation, - segment=segment, - db=db, - user=user, - user_message_timestamp=segment.created_at - or user_message_timestamp, + task = asyncio.create_task( + process_persisted_user_segment_response( + conversation_id=conversation_id, + user_id=user_id, + segment_id=segment.id, + tts_this_turn=tts_this_turn, + ) ) + register_user_response_task(conversation_id, task) else: await manager.send_message( conversation_id, @@ -756,6 +758,51 @@ async def websocket_endpoint( elif msg_type == MessageType.TTS_CANCEL: bump_tts_cancel_epoch(conversation_id) + elif msg_type == MessageType.TTS_REQUEST: + data = message.get("data") or {} + aid = data.get("assistant_message_id") or data.get( + "assistantMessageId" + ) + if not aid or not str(aid).strip(): + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "缺少助手消息 id"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) + continue + try: + seg_idx = int( + data.get("segment_index", data.get("segmentIndex", 0)) + ) + except (TypeError, ValueError): + seg_idx = 0 + st = data.get("segment_text") or data.get("segmentText") + st_val: str | None + if st is None: + st_val = None + else: + st_val = str(st).strip() or None + ok, err_msg = 
await handle_tts_request_on_demand( + conversation_id=conversation_id, + user_id=user_id, + assistant_message_id=str(aid).strip(), + segment_index=seg_idx, + segment_text=st_val, + db=db, + ) + if not ok: + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": err_msg or "朗读请求失败"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) + elif msg_type == MessageType.END_CONVERSATION: await conversation_service.end(conversation_id, user_id) diff --git a/api/app/features/memoir/service.py b/api/app/features/memoir/service.py index b4520f0..c40700d 100644 --- a/api/app/features/memoir/service.py +++ b/api/app/features/memoir/service.py @@ -81,7 +81,6 @@ class MemoirService: "relevant_chunks": [], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } bundle = await self._memory.retrieve(user_id, query, top_k=top_k) diff --git a/api/app/features/memoir/state_service.py b/api/app/features/memoir/state_service.py index b32d6f2..8246d9e 100644 --- a/api/app/features/memoir/state_service.py +++ b/api/app/features/memoir/state_service.py @@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.agents.stage_constants import ( + allowed_slot_names_for_stage, chat_bucket, normalize_chat_stage, ) @@ -136,6 +137,8 @@ async def update_slot( fallback=current_from_db, log_context={"user_id": user_id}, ) + if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db): + return coerce_memoir_state(state) slots = _slots_snapshot_for_merge( state.slots if isinstance(state.slots, dict) else None @@ -292,6 +295,8 @@ def update_slot_sync( fallback=current_from_db, log_context={"user_id": user_id}, ) + if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db): + return coerce_memoir_state(state) slots = _slots_snapshot_for_merge( state.slots if isinstance(state.slots, dict) else None diff --git 
a/api/app/features/memoir/story_pipeline_sync.py b/api/app/features/memoir/story_pipeline_sync.py index 1e63121..93221c4 100644 --- a/api/app/features/memoir/story_pipeline_sync.py +++ b/api/app/features/memoir/story_pipeline_sync.py @@ -11,18 +11,17 @@ import re import time import uuid from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from typing import Any from sqlalchemy import func, select from sqlalchemy.orm import Session, joinedload from app.agents.memoir.narrative_agent import NarrativeAgent -from app.agents.memoir.prompts import ( - format_evidence_chunks_for_prompt, - format_narrative_user_content, -) +from app.agents.memoir.prompts import format_narrative_user_content from app.agents.memoir.story_route_agent import ( APPEND_FIRST_CHAPTER_CATEGORIES, + FALLBACK_NEW_STORY_REASONS, PLAN_BATCH_MAX_SEGMENTS, StoryBatchPlan, StoryRouteAgent, @@ -59,6 +58,7 @@ from app.features.memoir.repo import ( mark_chapter_dirty_sync, reorder_chapter_story_links_by_life_order_sync, ) +from app.features.memory.evidence_format import format_evidence_chunks_for_prompt from app.features.story.models import Story, StoryVersion from app.features.story.sync_write import ( append_story_version_sync, @@ -72,6 +72,23 @@ from app.features.story.sync_write import ( logger = get_logger(__name__) +@dataclass +class StoryPipelineResult: + """Phase2 故事管线结果。 + + - 正常写入:``deferred=False``,``chapter`` 非空。 + - 低置信延迟:``deferred=True``,``chapter`` 为 None;调用方应把 ``defer_segment_ids`` + 标记为延迟态,不要置 ``narrated/processed``,也不要触发后置任务。 + """ + + chapter: Chapter | None + needs_cover: bool + dispatch_ids: set[str] + deferred: bool = False + defer_reason: str | None = None + defer_segment_ids: list[str] = field(default_factory=list) + + def _dialogue_lineage_dict_for_segment_ids( category_segments: list, segment_ids: list[str], @@ -99,7 +116,7 @@ def _dialogue_lineage_dict_for_segment_ids( def _evidence_link_ids( evidence: dict, -) -> tuple[list[str], list[str], 
list[str], list[str]]: +) -> tuple[list[str], list[str], list[str]]: """从 MemoryService.retrieve 结果提取稳定 ID 列表。""" chunks: list[str] = [] for c in evidence.get("relevant_chunks") or []: @@ -109,15 +126,11 @@ def _evidence_link_ids( for f in evidence.get("relevant_facts") or []: if isinstance(f, dict) and f.get("id"): facts.append(str(f["id"])) - timelines: list[str] = [] - for e in evidence.get("timeline_hints") or []: - if isinstance(e, dict) and e.get("id"): - timelines.append(str(e["id"])) summaries: list[str] = [] for s in evidence.get("relevant_summaries") or []: if isinstance(s, dict) and s.get("id"): summaries.append(str(s["id"])) - return chunks, facts, timelines, summaries + return chunks, facts, summaries def _story_prompt_meta_for_lineage( @@ -126,14 +139,13 @@ def _story_prompt_meta_for_lineage( memoir_correlation_id: str | None, top_k: int, ) -> dict: - c, f, t, s = _evidence_link_ids(evidence) + c, f, s = _evidence_link_ids(evidence) return { "memoir_retrieval": { "correlation_id": memoir_correlation_id, "top_k": top_k, "chunk_ids": c, "fact_ids": f, - "timeline_event_ids": t, "summary_ids": s, } } @@ -150,13 +162,13 @@ def _persist_story_lineage_sync( dialogue_lineage: dict | None = None, ) -> None: """写入 StoryEvidenceLink + 本版本 prompt_meta(可审计检索闭包)。""" - c, f, t, s = _evidence_link_ids(evidence) + c, f, s = _evidence_link_ids(evidence) replace_story_evidence_links_sync( session, story_id=story_id, chunk_ids=c, fact_ids=f, - timeline_event_ids=t, + timeline_event_ids=[], summary_ids=s, ) version.prompt_meta = _story_prompt_meta_for_lineage( @@ -669,6 +681,7 @@ def _resolve_append_target( route_decision == "new_story" and chapter_category in APPEND_FIRST_CHAPTER_CATEGORIES and candidate_stories + and decision_source not in FALLBACK_NEW_STORY_REASONS and len(oral_norm) <= int(settings.memoir_story_route_append_guardrail_oral_chars) ): @@ -959,9 +972,10 @@ def run_story_pipeline_for_category_batch( memoir_correlation_id: str | None = None, llm_fast: Any 
| None = None, memory_evidence: dict | None = None, -) -> tuple[Chapter | None, bool, set[str]]: - """ - 返回 (chapter, needs_cover_enqueue, story_ids_to_dispatch_after_commit)。 +) -> StoryPipelineResult: + """运行某 chapter_category 的 Phase2 写入管线。 + + 返回 :class:`StoryPipelineResult`。低置信路由会被延迟而不创建 Story/Chapter。 """ pipeline_phase_timings: dict[str, float] = {} narrative_agent = NarrativeAgent() @@ -992,7 +1006,6 @@ def run_story_pipeline_for_category_batch( "relevant_chunks": [], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } ev_elapsed = time.perf_counter() - _t_ev @@ -1082,8 +1095,46 @@ def run_story_pipeline_for_category_batch( valid_story_ids=valid_ids, story_meta=story_meta, ) + + single_route: Any = None + if plan is None: + single_route = route_agent.decide( + chapter_category=chapter_category, + chapter_title=title, + batch_transcript=route_transcript, + candidate_stories=candidates, + llm=llm_route, + valid_story_ids=valid_ids, + story_meta=story_meta, + ) pipeline_phase_timings["route"] = time.perf_counter() - _t0 + if ( + plan is None + and single_route is not None + and single_route.reason in FALLBACK_NEW_STORY_REASONS + and bool(settings.memoir_route_defer_enabled) + ): + defer_ids = [str(s.id) for s in category_segments] + logger.info( + "event=memoir_pipeline_route_deferred memoir_correlation_id={} user_id={} " + "chapter_category={} segment_count={} reason={} " + "msg=Phase2 路由低置信,本批 segment 进入延迟池", + memoir_correlation_id or "", + user_id, + chapter_category, + len(defer_ids), + single_route.reason, + ) + return StoryPipelineResult( + chapter=None, + needs_cover=False, + dispatch_ids=set(), + deferred=True, + defer_reason=str(single_route.reason), + defer_segment_ids=defer_ids, + ) + chapter = _ensure_chapter_record( session, user_id=user_id, @@ -1118,17 +1169,12 @@ def run_story_pipeline_for_category_batch( fidelity_llm=llm_fidelity, ) else: - route = route_agent.decide( - 
chapter_category=chapter_category, - chapter_title=title, - batch_transcript=route_transcript, - candidate_stories=candidates, - llm=llm_route, - valid_story_ids=valid_ids, - story_meta=story_meta, + route = single_route + decision_source = ( + route.reason + if route.reason in FALLBACK_NEW_STORY_REASONS + else ("fallback_no_llm" if not llm_route else "single_decide") ) - - decision_source = "fallback_no_llm" if not llm else "single_decide" target_story_id, existing_for_narrative, decision_source = ( _resolve_append_target( session, @@ -1199,4 +1245,8 @@ def run_story_pipeline_for_category_batch( timing_parts, ) - return chapter, needs_cover, dispatch_ids + return StoryPipelineResult( + chapter=chapter, + needs_cover=needs_cover, + dispatch_ids=dispatch_ids, + ) diff --git a/api/app/features/memory/chunker.py b/api/app/features/memory/chunker.py index f110370..8a35f65 100644 --- a/api/app/features/memory/chunker.py +++ b/api/app/features/memory/chunker.py @@ -13,10 +13,14 @@ def chunk_transcript( text = text.strip() if len(text) <= max_chars: return [text] if text else [] + if max_chars <= 0: + raise ValueError("max_chars must be positive") + if overlap_chars < 0: + raise ValueError("overlap_chars cannot be negative") + overlap = min(overlap_chars, max_chars - 1) chunks: list[str] = [] start = 0 - step = max_chars - overlap_chars while start < len(text): end = start + max_chars @@ -31,6 +35,12 @@ def chunk_transcript( break if chunk.strip(): chunks.append(chunk.strip()) - start += len(chunk) if chunk else step + if not chunk: + start += max_chars - overlap + continue + next_start = end - overlap + if next_start <= start: + next_start = start + len(chunk) + start = next_start return chunks diff --git a/api/app/features/memory/embedding_scheduler.py b/api/app/features/memory/embedding_scheduler.py new file mode 100644 index 0000000..f182323 --- /dev/null +++ b/api/app/features/memory/embedding_scheduler.py @@ -0,0 +1,28 @@ +"""Memory embedding scheduling boundary.""" 
+ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class MemoryEmbeddingRequest: + user_id: str + source_id: str + memoir_correlation_id: str | None = None + + +class MemoryEmbeddingScheduler: + """Adapter around the Celery embedding task name and queue policy.""" + + def schedule(self, request: MemoryEmbeddingRequest) -> str | None: + from app.tasks.memory_enrichment_tasks import schedule_memory_embedding + + return schedule_memory_embedding( + request.user_id, + request.source_id, + memoir_correlation_id=request.memoir_correlation_id, + ) + + +__all__ = ["MemoryEmbeddingRequest", "MemoryEmbeddingScheduler"] diff --git a/api/app/features/memory/embedding_service.py b/api/app/features/memory/embedding_service.py new file mode 100644 index 0000000..5f89a0a --- /dev/null +++ b/api/app/features/memory/embedding_service.py @@ -0,0 +1,182 @@ +"""Memory embedding service boundary.""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.logging import get_logger +from app.features.memory.repo import ( + list_chunks_for_source, + set_chunk_embedding_status, + set_source_embedding_status, + update_chunk_embedding, +) +from app.ports.embedding import EmbeddingProvider + +logger = get_logger(__name__) + + +def _short_error(exc: BaseException | str, *, max_chars: int = 500) -> str: + text = str(exc) + if len(text) > max_chars: + return text[: max_chars - 3] + "..." 
+ return text + + +async def _commit_if_available(db: AsyncSession) -> None: + commit = getattr(db, "commit", None) + if commit is not None: + await commit() + + +class MemoryEmbeddingService: + """Embeds persisted memory chunks and records source/chunk status.""" + + def __init__( + self, + db: AsyncSession, + *, + embedding_provider: EmbeddingProvider | None = None, + ) -> None: + self._db = db + self._embedding = embedding_provider + + async def embed_source( + self, + user_id: str, + source_id: str, + *, + raise_on_failure: bool = False, + ) -> dict: + chunks = await list_chunks_for_source( + self._db, + user_id=user_id, + source_id=source_id, + include_excluded=True, + ) + if not chunks: + await set_source_embedding_status( + self._db, + source_id=source_id, + user_id=user_id, + status="skipped", + error="no_chunks", + ) + await _commit_if_available(self._db) + return {"status": "skipped", "reason": "no_chunks", "chunks": 0} + + if self._embedding is None: + err = "embedding_provider_missing" + await self._mark_failed(user_id, source_id, [c.id for c in chunks], err) + if raise_on_failure: + raise RuntimeError(err) + return {"status": "failed", "error": err, "chunks": len(chunks)} + + await set_source_embedding_status( + self._db, + source_id=source_id, + user_id=user_id, + status="running", + error=None, + ) + await _commit_if_available(self._db) + + try: + texts = [c.content for c in chunks] + raw_embeddings = await self._embedding.embed_texts(texts) + embeddings = list(raw_embeddings or []) + except Exception as e: + err = _short_error(e) + await self._mark_failed(user_id, source_id, [c.id for c in chunks], err) + logger.warning( + "event=memory_embedding_failed user_id={} source_id={} chunks={} exc_type={} exc={}", + user_id, + source_id, + len(chunks), + type(e).__name__, + err, + ) + if raise_on_failure: + raise + return {"status": "failed", "error": err, "chunks": len(chunks)} + + vectors_written = 0 + failed_chunk_ids: list[str] = [] + for chunk, emb in 
zip(chunks, embeddings, strict=False): + if emb: + vectors_written += 1 + await update_chunk_embedding(self._db, chunk.id, emb) + else: + failed_chunk_ids.append(chunk.id) + await set_chunk_embedding_status( + self._db, + chunk.id, + status="failed", + error="empty_embedding", + ) + + if len(embeddings) != len(chunks): + missing = chunks[len(embeddings) :] + failed_chunk_ids.extend(c.id for c in missing) + for chunk in missing: + await set_chunk_embedding_status( + self._db, + chunk.id, + status="failed", + error="embedding_count_mismatch", + ) + logger.warning( + "event=memory_embedding_count_mismatch user_id={} source_id={} chunks={} embeddings={}", + user_id, + source_id, + len(chunks), + len(embeddings), + ) + + status = "success" + error = None + if failed_chunk_ids: + status = "partial" if vectors_written else "failed" + error = f"failed_chunks={len(failed_chunk_ids)}" + await set_source_embedding_status( + self._db, + source_id=source_id, + user_id=user_id, + status=status, + error=error, + ) + await _commit_if_available(self._db) + if status == "failed" and raise_on_failure: + raise RuntimeError(error or "embedding_failed") + return { + "status": status, + "chunks": len(chunks), + "vectors_written": vectors_written, + "failed_chunks": failed_chunk_ids, + } + + async def _mark_failed( + self, + user_id: str, + source_id: str, + chunk_ids: list[str], + error: str, + ) -> None: + await set_source_embedding_status( + self._db, + source_id=source_id, + user_id=user_id, + status="failed", + error=error, + ) + for chunk_id in chunk_ids: + await set_chunk_embedding_status( + self._db, + chunk_id, + status="failed", + error=error, + ) + await _commit_if_available(self._db) + + +__all__ = ["MemoryEmbeddingService"] diff --git a/api/app/features/memory/enrichment.py b/api/app/features/memory/enrichment.py index ad05ba1..135d355 100644 --- a/api/app/features/memory/enrichment.py +++ b/api/app/features/memory/enrichment.py @@ -12,7 +12,6 @@ from typing import Any from 
sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.core.langchain_llm import ainvoke_json_object from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.features.memory.enrichment_pipeline import ( @@ -23,24 +22,37 @@ from app.features.memory.enrichment_pipeline import ( from app.features.memory.llm_schemas import ( EnrichmentPayload, enrichment_payload_to_fact_dicts, - parse_json_payload, ) from app.features.memory.models import MemoryChunk, MemorySource from app.features.memory.repo import ( create_memory_fact, create_memory_summary, + set_source_enrichment_status, ) from app.features.user.models import User logger = get_logger(__name__) +def _short_error(exc: BaseException | str, *, max_chars: int = 500) -> str: + text = str(exc) + if len(text) > max_chars: + return text[: max_chars - 3] + "..." + return text + + +async def _commit_if_available(db: AsyncSession) -> None: + commit = getattr(db, "commit", None) + if commit is not None: + await commit() + + def _lineage_snapshot_from_source(source: MemorySource | None) -> dict | None: raw = getattr(source, "lineage_json", None) if source else None return raw if isinstance(raw, dict) and raw else None -def _resolve_llm() -> Any | None: +def _resolve_gateway_llm() -> Any | None: try: return LlmGateway().langchain_llm_for( LlmUseCase("memory.enrichment", fast=True) @@ -89,17 +101,11 @@ async def _run_enrichment_llm_async( if not llm or not (numbered or "").strip(): return None prompt = _enrichment_prompt(numbered, narrator_label) - try: - raw = await ainvoke_json_object( - llm, - prompt, - max_tokens=8192, - agent="memory.enrichment", - ) - return parse_json_payload(raw, EnrichmentPayload) - except (TypeError, ValueError) as e: - logger.warning("enrichment LLM async 解析失败: {}", e) - return None + return await LlmGateway().json_object( + prompt, + EnrichmentPayload, + use_case=LlmUseCase("memory.enrichment", fast=True, max_tokens=8192), + ) 
async def enrich_memory_after_ingest_async( @@ -107,15 +113,44 @@ async def enrich_memory_after_ingest_async( user_id: str, source_id: str, llm: Any | None = None, -) -> None: + *, + raise_on_failure: bool = False, +) -> dict: from app.core.config import settings if not settings.memory_enrichment_enabled: - return + await set_source_enrichment_status( + db, + source_id=source_id, + user_id=user_id, + status="skipped", + error="disabled", + ) + await _commit_if_available(db) + return {"status": "skipped", "reason": "disabled"} if llm is None: - llm = _resolve_llm() + llm = _resolve_gateway_llm() if not llm: - return + err = "llm_unavailable" + await set_source_enrichment_status( + db, + source_id=source_id, + user_id=user_id, + status="failed", + error=err, + ) + await _commit_if_available(db) + if raise_on_failure: + raise RuntimeError(err) + return {"status": "failed", "error": err} + await set_source_enrichment_status( + db, + source_id=source_id, + user_id=user_id, + status="running", + error=None, + ) + await _commit_if_available(db) narrator_name: str | None = None u_row = await db.get(User, user_id) if u_row and (u_row.nickname or "").strip(): @@ -128,7 +163,15 @@ async def enrich_memory_after_ingest_async( result = await db.execute(stmt) chunks = list(result.unique().scalars().all()) if not chunks: - return + await set_source_enrichment_status( + db, + source_id=source_id, + user_id=user_id, + status="skipped", + error="no_chunks", + ) + await _commit_if_available(db) + return {"status": "skipped", "reason": "no_chunks"} src_row = await db.get(MemorySource, source_id) lineage_snapshot = _lineage_snapshot_from_source(src_row) chunk_ids = [c.id for c in chunks] @@ -139,9 +182,30 @@ async def enrich_memory_after_ingest_async( ) narrator_label = (narrator_name or "").strip() or "叙述者" - payload = await _run_enrichment_llm_async(llm, numbered, narrator_label) - if payload is None: - return + try: + payload = await _run_enrichment_llm_async(llm, numbered, 
narrator_label) + if payload is None: + raise ValueError("empty_enrichment_payload") + except Exception as e: + err = _short_error(e) + await set_source_enrichment_status( + db, + source_id=source_id, + user_id=user_id, + status="failed", + error=err, + ) + await _commit_if_available(db) + logger.warning( + "event=memory_enrichment_llm_failed user_id={} source_id={} exc_type={} exc={}", + user_id, + source_id, + type(e).__name__, + err, + ) + if raise_on_failure: + raise + return {"status": "failed", "error": err} session_summary_text = str(payload.summary or "").strip() if session_summary_text: @@ -175,3 +239,19 @@ async def enrich_memory_after_ingest_async( status="confirmed", lineage_json=lineage_snapshot, ) + + await set_source_enrichment_status( + db, + source_id=source_id, + user_id=user_id, + status="success", + error=None, + ) + await _commit_if_available(db) + return { + "status": "success", + "source_id": source_id, + "chunks": len(chunks), + "facts": len(seen), + "summary": bool(session_summary_text), + } diff --git a/api/app/features/memory/evidence.py b/api/app/features/memory/evidence.py index f27a053..d015db9 100644 --- a/api/app/features/memory/evidence.py +++ b/api/app/features/memory/evidence.py @@ -1,7 +1,7 @@ """ 证据包组装:跨 memory + story 的检索结果合并(业务层,非纯 repo)。 -Memory evidence 只保留 async 单链路:chunk 原文为首要证据,结构化事实/时间线/ +Memory evidence 只保留 async 单链路:chunk 原文为首要证据,结构化事实/ 摘要/故事均按本次 query 命中进入 evidence,不再做 rolling/recent 历史降级。 """ @@ -12,7 +12,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.features.memory.repo import ( list_summaries_for_evidence_async, search_facts_for_user_async, - search_timeline_events_for_user_async, ) from app.features.story.repo import list_recent_stories_for_evidence @@ -20,7 +19,6 @@ EMPTY_EVIDENCE_BUNDLE: dict = { "relevant_chunks": [], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } @@ -38,19 +36,6 @@ def _facts_to_dicts(facts) -> list[dict]: ] -def 
_timeline_to_dicts(events) -> list[dict]: - return [ - { - "id": e.id, - "event_year": e.event_year, - "event_date": e.event_date, - "title": e.title, - "description": e.description, - } - for e in events - ] - - def _stories_to_dicts(story_rows) -> list[dict]: return [ { @@ -69,7 +54,6 @@ async def fetch_evidence_metadata_async( ) -> dict: """非 chunk 证据(async)。""" facts = await search_facts_for_user_async(db, user_id, q, top_k) - events = await search_timeline_events_for_user_async(db, user_id, q, top_k) relevant_summaries = await list_summaries_for_evidence_async( db, user_id=user_id, q=q, limit=top_k ) @@ -78,7 +62,6 @@ async def fetch_evidence_metadata_async( ) return { "relevant_facts": _facts_to_dicts(facts), - "timeline_hints": _timeline_to_dicts(events), "relevant_summaries": relevant_summaries, "relevant_stories": _stories_to_dicts(story_rows), } diff --git a/api/app/features/memory/evidence_format.py b/api/app/features/memory/evidence_format.py index 7108e67..278f642 100644 --- a/api/app/features/memory/evidence_format.py +++ b/api/app/features/memory/evidence_format.py @@ -81,15 +81,11 @@ def format_user_memory_for_chat_display( def format_evidence_chunks_for_chat_prompt(evidence: dict) -> str: - """聊天访谈专用:将检索 bundle 格式化为带编号引用与安全说明的短文本。 - - 与 `format_evidence_chunks_for_prompt` 并行存在;memoir/叙事流水线仍用后者,避免牵连成稿。 - """ + """聊天访谈专用:将检索 bundle 格式化为带编号引用与安全说明的短文本.""" chunks = evidence.get("relevant_chunks") or [] chunks = dedupe_evidence_chunk_rows(chunks[:10]) summaries = evidence.get("relevant_summaries") or [] facts = evidence.get("relevant_facts") or [] - timeline = evidence.get("timeline_hints") or [] stories = evidence.get("relevant_stories") or [] header = ( @@ -143,20 +139,6 @@ def format_evidence_chunks_for_chat_prompt(evidence: dict) -> str: safe = format_user_memory_for_chat_display(fact_line) lines.append(f"[M{n}] {safe}") - for t in timeline[:5]: - if isinstance(t, dict): - title = (t.get("title") or "").strip() - year = t.get("event_year") - desc = 
(t.get("description") or "").strip() - line = " ".join( - x for x in (str(year) if year is not None else "", title, desc) if x - ) - if not line: - continue - n += 1 - safe = format_user_memory_for_chat_display(line) - lines.append(f"[M{n}] {safe}") - for st in stories[:3]: if isinstance(st, dict): title = (st.get("title") or "").strip() @@ -175,15 +157,11 @@ def format_evidence_chunks_for_chat_prompt(evidence: dict) -> str: def format_evidence_chunks_for_prompt(evidence: dict) -> str: - """将 MemoryService.retrieve 结果格式化为简短文本,供叙事与访谈 prompt 使用。 - - 包含 chunks、摘要(若有)、confirmed facts、timeline、故事摘要(若有)。 - """ + """将 MemoryService.retrieve 结果格式化为简短文本,供叙事与访谈 prompt 使用.""" chunks = evidence.get("relevant_chunks") or [] chunks = dedupe_evidence_chunk_rows(chunks[:10]) summaries = evidence.get("relevant_summaries") or [] facts = evidence.get("relevant_facts") or [] - timeline = evidence.get("timeline_hints") or [] stories = evidence.get("relevant_stories") or [] parts: list[str] = [] for c in chunks: @@ -212,16 +190,6 @@ def format_evidence_chunks_for_prompt(evidence: dict) -> str: parts.append(f"{subj}:{pred}") else: parts.append(f"{getattr(f, 'subject', '')}:{getattr(f, 'predicate', '')}") - for t in timeline[:5]: - if isinstance(t, dict): - title = (t.get("title") or "").strip() - year = t.get("event_year") - desc = (t.get("description") or "").strip() - line = " ".join( - x for x in (str(year) if year is not None else "", title, desc) if x - ) - if line: - parts.append(line) for st in stories[:3]: if isinstance(st, dict): title = (st.get("title") or "").strip() diff --git a/api/app/features/memory/extractor.py b/api/app/features/memory/extractor.py deleted file mode 100644 index 48a208d..0000000 --- a/api/app/features/memory/extractor.py +++ /dev/null @@ -1,98 +0,0 @@ -"""从 transcript 块中抽取结构化事实(async LLM + JSON)。""" - -from __future__ import annotations - -from typing import Any - -from app.core.langchain_llm import ainvoke_json_object -from app.core.llm_gateway import 
LlmGateway, LlmUseCase -from app.core.logging import get_logger -from app.features.memory.llm_schemas import ( - FactsExtractionPayload, - facts_payload_to_dicts, - parse_json_payload, -) - -logger = get_logger(__name__) - - -def _max_transcript_chars() -> int: - from app.core.config import settings - - return settings.memory_enrichment_max_chars - - -def _facts_extraction_instructions(narrator_label: str) -> str: - return ( - "你是回忆录事实抽取助手。用户正在口述人生回忆,所有内容默认是**过去发生的事**," - "而非当前或未来计划(除非原文明确说「现在」「打算」「准备将要」等)。\n\n" - "## 抽取规则\n" - "1. subject 必须用明确的人名或固定称谓:\n" - f" - 叙述者本人统一用「{narrator_label}」\n" - " - 其他人用全名或稳定专名(如「王伟」),禁止用「他」「她」「我」「我们大伙」等代词作 subject;" - "若代词在上下文中可唯一解析为某人,则 subject 写该人姓名/专名\n" - "2. 事件、职务变动、地点迁移等一律按**过去回忆**理解;travel/调动/命令类表述勿写成「即将要做」" - "除非原文明确为未来时态\n" - "3. 若可推断大约年代或人生阶段,将 approximate_era 写入 object_json(与 value 等字段并存)," - '例如 "1990年代"、"2001年"、"退休后"、"30岁前后"\n' - "4. fact_type: person|event|relation|place|milestone\n" - "5. predicate:简短中文谓语(如「出生地」「担任职务」「调往」)\n" - "6. object_json:字符串或对象;可含 value、approximate_era 等\n" - "7. confidence 0..1;source_chunk_id 必须等于某段 [chunk_id=...] 
中的 id\n\n" - '只输出 JSON:{"facts":[...]},无事实则 {"facts":[]}。\n\n' - ) - - -async def extract_facts_from_transcript_async( - llm: Any, - numbered_blocks: str, - *, - narrator_name: str | None = None, -) -> list[dict]: - """带 chunk_id 标记的文本 → 事实列表。""" - if not llm or not (numbered_blocks or "").strip(): - return [] - text = numbered_blocks.strip()[: _max_transcript_chars()] - narrator_label = (narrator_name or "").strip() or "叙述者" - prompt = _facts_extraction_instructions(narrator_label) + text - try: - raw = await ainvoke_json_object( - llm, - prompt, - max_tokens=4096, - agent="memory.extract_facts_async", - ) - parsed = parse_json_payload(raw, FactsExtractionPayload) - if parsed is None: - return [] - return facts_payload_to_dicts(parsed) - except (TypeError, ValueError) as e: - logger.warning("extract_facts_from_transcript_async 解析失败: {}", e) - return [] - - -async def extract_facts(chunk_text: str, *, user_id: str) -> list[dict]: - """兼容旧接口:单块文本(无 chunk id 时传空 source_chunk_id)。""" - from app.core.db import AsyncSessionLocal - from app.features.user.models import User - - llm = LlmGateway().langchain_llm_for( - LlmUseCase("memory.extract_facts.compat", fast=True) - ) - narrator_name: str | None = None - try: - async with AsyncSessionLocal() as db: - u = await db.get(User, user_id) - if u and (u.nickname or "").strip(): - narrator_name = (u.nickname or "").strip() - except Exception: - pass - - blocks = f"[chunk_id=null]\n{chunk_text}" - facts = await extract_facts_from_transcript_async( - llm, blocks, narrator_name=narrator_name - ) - for f in facts: - if f.get("source_chunk_id") in (None, "null", ""): - f["source_chunk_id"] = None - return facts diff --git a/api/app/features/memory/ingest_service.py b/api/app/features/memory/ingest_service.py index 176d4ba..c791e1a 100644 --- a/api/app/features/memory/ingest_service.py +++ b/api/app/features/memory/ingest_service.py @@ -10,6 +10,11 @@ from app.features.conversation.lineage_schemas import ( 
primary_user_message_id_from_lineage, ) from app.features.memory.chunker import chunk_transcript +from app.features.memory.embedding_scheduler import ( + MemoryEmbeddingRequest, + MemoryEmbeddingScheduler, +) +from app.features.memory.embedding_service import MemoryEmbeddingService from app.features.memory.enrichment_scheduler import ( MemoryEnrichmentRequest, MemoryEnrichmentScheduler, @@ -17,7 +22,6 @@ from app.features.memory.enrichment_scheduler import ( from app.features.memory.repo import ( create_chunk, create_source, - update_chunk_embedding, ) from app.ports.embedding import EmbeddingProvider @@ -32,10 +36,12 @@ class MemoryIngestService: db: AsyncSession, *, embedding_provider: EmbeddingProvider | None = None, + embedding_scheduler: MemoryEmbeddingScheduler | None = None, enrichment_scheduler: MemoryEnrichmentScheduler | None = None, ) -> None: self._db = db self._embedding = embedding_provider + self._embedding_scheduler = embedding_scheduler or MemoryEmbeddingScheduler() self._enrichment_scheduler = enrichment_scheduler or MemoryEnrichmentScheduler() async def ingest_transcript( @@ -74,19 +80,17 @@ class MemoryIngestService: chunk_records.append((chunk.id, content)) await self._db.flush() - - vectors_written = 0 - if self._embedding and chunk_records: - texts = [content for _, content in chunk_records] - embeddings = await self._embedding.embed_texts(texts) - for (chunk_id, _), emb in zip( - chunk_records, embeddings, strict=False - ): - if emb: - vectors_written += 1 - await update_chunk_embedding(self._db, chunk_id, emb) - await self._db.commit() + + embedding_result = await MemoryEmbeddingService( + self._db, + embedding_provider=self._embedding, + ).embed_source(user_id, source.id) + embedding_task_id = self._schedule_embedding_retry_if_needed( + user_id, + source.id, + embedding_result, + ) emb_ok = self._embedding.is_available() if self._embedding else False enrichment_task_id = self._enrichment_scheduler.schedule( 
MemoryEnrichmentRequest(user_id=user_id, source_id=source.id) @@ -94,13 +98,16 @@ class MemoryIngestService: logger.info( "event=memory_ingest_done user_id={} conversation_id={} source_id={} " - "chunks={} vectors_written={} embedding_available={} enrichment_enabled={} enrichment_task_id={}", + "chunks={} vectors_written={} embedding_status={} embedding_available={} " + "embedding_task_id={} enrichment_enabled={} enrichment_task_id={}", user_id, conversation_id, source.id, len(chunk_records), - vectors_written, + embedding_result.get("vectors_written", 0), + embedding_result.get("status"), emb_ok, + embedding_task_id, settings.memory_enrichment_enabled, enrichment_task_id, ) @@ -152,17 +159,29 @@ class MemoryIngestService: chunk_records.append((chunk.id, content)) await self._db.flush() + await self._db.commit() vectors_written = 0 - if self._embedding and chunk_records: - texts = [content for _, content in chunk_records] - embeddings = await self._embedding.embed_texts(texts) - for (chunk_id, _), emb in zip(chunk_records, embeddings, strict=False): - if emb: - vectors_written += 1 - await update_chunk_embedding(self._db, chunk_id, emb) + embedding_retry_task_ids: list[str] = [] + embedding_statuses: dict[str, int] = {} + embedding_service = MemoryEmbeddingService( + self._db, + embedding_provider=self._embedding, + ) + for source_id in source_ids: + result = await embedding_service.embed_source(user_id, source_id) + vectors_written += int(result.get("vectors_written") or 0) + status = str(result.get("status") or "unknown") + embedding_statuses[status] = embedding_statuses.get(status, 0) + 1 + task_id = self._schedule_embedding_retry_if_needed( + user_id, + source_id, + result, + memoir_correlation_id=memoir_correlation_id, + ) + if task_id: + embedding_retry_task_ids.append(task_id) - await self._db.commit() emb_ok = self._embedding.is_available() if self._embedding else False task_ids = self._enrichment_scheduler.schedule_many( user_id, @@ -172,16 +191,38 @@ 
class MemoryIngestService: logger.info( "event=memory_ingest_batch_done user_id={} sources={} chunks={} " - "vectors_written={} embedding_available={} enrichment_enabled={} enrichment_tasks={}", + "vectors_written={} embedding_available={} embedding_statuses={} " + "embedding_retry_tasks={} enrichment_enabled={} enrichment_tasks={}", user_id, len(source_ids), len(chunk_records), vectors_written, emb_ok, + embedding_statuses, + len(embedding_retry_task_ids), settings.memory_enrichment_enabled, len(task_ids), ) return source_ids + def _schedule_embedding_retry_if_needed( + self, + user_id: str, + source_id: str, + embedding_result: dict, + *, + memoir_correlation_id: str | None = None, + ) -> str | None: + status = str(embedding_result.get("status") or "") + if status not in {"failed", "partial"}: + return None + return self._embedding_scheduler.schedule( + MemoryEmbeddingRequest( + user_id=user_id, + source_id=source_id, + memoir_correlation_id=memoir_correlation_id, + ) + ) + __all__ = ["MemoryIngestService"] diff --git a/api/app/features/memory/llm_schemas.py b/api/app/features/memory/llm_schemas.py index 67b57b8..cec8d92 100644 --- a/api/app/features/memory/llm_schemas.py +++ b/api/app/features/memory/llm_schemas.py @@ -2,13 +2,10 @@ from __future__ import annotations -import json -from typing import Any, TypeVar +from typing import Any from pydantic import BaseModel, Field, field_validator -TModel = TypeVar("TModel", bound=BaseModel) - class ExtractedFactItem(BaseModel): fact_type: str = "event" @@ -38,49 +35,6 @@ class EnrichmentPayload(BaseModel): facts: list[ExtractedFactItem] = Field(default_factory=list) -class SessionSummaryPayload(BaseModel): - summary: str = "" - - -class RollingSummaryPayload(BaseModel): - rolling_summary: str = "" - - -class TimelineEventItem(BaseModel): - event_year: int | None = None - event_date: str | None = None - title: str = "" - description: str | None = None - source_fact_ids: list[str] = Field(default_factory=list) - - 
@field_validator("source_fact_ids", mode="before") - @classmethod - def _coerce_sf(cls, v: object) -> list[str]: - if v is None: - return [] - if isinstance(v, str): - return [v] if v else [] - if isinstance(v, list): - return [str(x) for x in v if x] - return [] - - -class TimelineEventsPayload(BaseModel): - events: list[TimelineEventItem] = Field(default_factory=list) - - -def parse_json_payload(raw: str, model: type[TModel]) -> TModel | None: - """解析 invoke_json_object 返回的 JSON 字符串。""" - from app.core.json_utils import extract_json_payload - - try: - cleaned = extract_json_payload(raw) - data = json.loads(cleaned) - return model.model_validate(data) - except (json.JSONDecodeError, ValueError, TypeError): - return None - - def facts_payload_to_dicts(payload: FactsExtractionPayload) -> list[dict]: out: list[dict] = [] for item in payload.facts: @@ -95,21 +49,3 @@ def facts_payload_to_dicts(payload: FactsExtractionPayload) -> list[dict]: def enrichment_payload_to_fact_dicts(payload: EnrichmentPayload) -> list[dict]: """将 EnrichmentPayload.facts 转为与 extract_facts 一致的字典列表。""" return facts_payload_to_dicts(FactsExtractionPayload(facts=list(payload.facts))) - - -def timeline_payload_to_dicts(payload: TimelineEventsPayload) -> list[dict]: - out: list[dict] = [] - for ev in payload.events: - title = (ev.title or "").strip() - if not title: - continue - out.append( - { - "event_year": ev.event_year, - "event_date": ev.event_date, - "title": title, - "description": ev.description, - "source_fact_ids": ev.source_fact_ids or [], - } - ) - return out[:20] diff --git a/api/app/features/memory/models.py b/api/app/features/memory/models.py index 1311a43..586c318 100644 --- a/api/app/features/memory/models.py +++ b/api/app/features/memory/models.py @@ -28,6 +28,10 @@ class MemorySource(Base): speaker = Column(String, nullable=True) captured_at = Column(DateTime(timezone=True), nullable=True) status = Column(String, default="active") + embedding_status = Column(String, 
default="pending") + embedding_error = Column(Text, nullable=True) + enrichment_status = Column(String, default="pending") + enrichment_error = Column(Text, nullable=True) conversation_id = Column(String, ForeignKey("conversations.id"), nullable=True) lineage_json = Column(JSON, nullable=True) primary_user_message_id = Column(String, nullable=True) @@ -52,6 +56,8 @@ class MemoryChunk(Base): event_year = Column(Integer, nullable=True) metadata_json = Column(JSON, nullable=True) is_excluded = Column(Boolean, default=False) + embedding_status = Column(String, default="pending") + embedding_error = Column(Text, nullable=True) created_at = Column(DateTime(timezone=True), default=utc_now) source = relationship("MemorySource", back_populates="chunks") diff --git a/api/app/features/memory/repo.py b/api/app/features/memory/repo.py index ace67f3..45c433f 100644 --- a/api/app/features/memory/repo.py +++ b/api/app/features/memory/repo.py @@ -1,10 +1,11 @@ -"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access.""" +"""Memory repository — MemorySource, MemoryChunk, and MemoryFact data access.""" import uuid from datetime import datetime, timedelta, timezone -from sqlalchemy import delete, literal, or_, select, text, tuple_, update +from sqlalchemy import cast, literal, or_, select, text, tuple_, update from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.types import String as SqlString from app.features.memory.models import ( MemoryChunk, @@ -12,7 +13,6 @@ from app.features.memory.models import ( MemoryFact, MemorySource, MemorySummary, - TimelineEvent, ) @@ -37,6 +37,8 @@ async def create_source( user_id=user_id, source_type=source_type, raw_text=raw_text, + embedding_status="pending", + enrichment_status="pending", conversation_id=conversation_id, lineage_json=lineage_json, primary_user_message_id=primary_user_message_id, @@ -61,6 +63,7 @@ async def create_chunk( user_id=user_id, content=content, chunk_index=chunk_index, + 
embedding_status="pending", ) db.add(chunk) return chunk @@ -73,6 +76,75 @@ async def update_chunk_embedding( chunk = await db.get(MemoryChunk, chunk_id) if chunk: chunk.embedding = embedding + chunk.embedding_status = "success" + chunk.embedding_error = None + + +async def set_chunk_embedding_status( + db: AsyncSession, + chunk_id: str, + *, + status: str, + error: str | None = None, +) -> bool: + chunk = await db.get(MemoryChunk, chunk_id) + if chunk is None: + return False + chunk.embedding_status = status + chunk.embedding_error = error + return True + + +async def set_source_embedding_status( + db: AsyncSession, + *, + source_id: str, + user_id: str, + status: str, + error: str | None = None, +) -> bool: + source = await db.get(MemorySource, source_id) + if source is None or source.user_id != user_id: + return False + source.embedding_status = status + source.embedding_error = error + return True + + +async def set_source_enrichment_status( + db: AsyncSession, + *, + source_id: str, + user_id: str, + status: str, + error: str | None = None, +) -> bool: + source = await db.get(MemorySource, source_id) + if source is None or source.user_id != user_id: + return False + source.enrichment_status = status + source.enrichment_error = error + return True + + +async def list_chunks_for_source( + db: AsyncSession, + *, + user_id: str, + source_id: str, + include_excluded: bool = True, +) -> list[MemoryChunk]: + stmt = ( + select(MemoryChunk) + .where(MemoryChunk.user_id == user_id, MemoryChunk.source_id == source_id) + .order_by(MemoryChunk.chunk_index.asc(), MemoryChunk.id.asc()) + ) + if not include_excluded: + stmt = stmt.where( + or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)) + ) + result = await db.execute(stmt) + return list(result.unique().scalars().all()) async def get_chunks_by_ids( @@ -114,7 +186,11 @@ async def search_facts_for_user_async( .where( MemoryFact.user_id == user_id, MemoryFact.status == "confirmed", - 
or_(MemoryFact.subject.ilike(pat), MemoryFact.predicate.ilike(pat)), + or_( + MemoryFact.subject.ilike(pat), + MemoryFact.predicate.ilike(pat), + cast(MemoryFact.object_json, SqlString).ilike(pat), + ), ) .order_by(MemoryFact.created_at.desc()) .limit(limit) @@ -139,29 +215,6 @@ async def mark_facts_stale_for_excluded_chunk( return int(res.rowcount or 0) -async def search_timeline_events_for_user_async( - db: AsyncSession, user_id: str, query: str, limit: int = 20 -) -> list[TimelineEvent]: - q = (query or "").strip() - if not q: - return [] - pat = f"%{q}%" - stmt = ( - select(TimelineEvent) - .where( - TimelineEvent.user_id == user_id, - or_( - TimelineEvent.title.ilike(pat), - TimelineEvent.description.ilike(pat), - ), - ) - .order_by(TimelineEvent.event_year.desc().nullslast()) - .limit(limit) - ) - result = await db.execute(stmt) - return list(result.unique().scalars().all()) - - async def search_chunks_vector( db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20 ) -> list[dict]: @@ -207,22 +260,6 @@ async def list_users_with_recent_chunks(db: AsyncSession, *, hours: int) -> list return list(result.scalars().all()) -async def get_timeline_events_for_user( - db: AsyncSession, user_id: str, limit: int = 20 -) -> list[TimelineEvent]: - """Fetch timeline events for user.""" - stmt = ( - select(TimelineEvent) - .where(TimelineEvent.user_id == user_id) - .order_by( - TimelineEvent.event_year.desc().nullslast(), TimelineEvent.created_at.desc() - ) - .limit(limit) - ) - result = await db.execute(stmt) - return list(result.unique().scalars().all()) - - async def list_storage_keys_for_conversation( db: AsyncSession, conversation_id: str ) -> list[str]: @@ -302,46 +339,6 @@ async def set_memory_fact_status( return True -async def delete_timeline_events_by_memory_source( - db: AsyncSession, *, user_id: str, memory_source_id: str -) -> int: - stmt = delete(TimelineEvent).where( - TimelineEvent.user_id == user_id, - TimelineEvent.memory_source_id == 
memory_source_id, - ) - result = await db.execute(stmt) - return result.rowcount or 0 - - -async def create_timeline_event( - db: AsyncSession, - *, - user_id: str, - event_year: int | None, - event_date: str | None, - title: str, - description: str | None, - person_refs: list | None = None, - source_fact_ids: list[str] | None = None, - memory_source_id: str | None = None, - lineage_json: dict | None = None, -) -> TimelineEvent: - row = TimelineEvent( - id=_new_id(), - user_id=user_id, - memory_source_id=memory_source_id, - event_year=event_year, - event_date=event_date, - title=title, - description=description, - person_refs=person_refs, - source_fact_ids=source_fact_ids, - lineage_json=lineage_json, - ) - db.add(row) - return row - - async def create_curation_action( db: AsyncSession, *, diff --git a/api/app/features/memory/retrieval_service.py b/api/app/features/memory/retrieval_service.py index e78a863..5c2b7f7 100644 --- a/api/app/features/memory/retrieval_service.py +++ b/api/app/features/memory/retrieval_service.py @@ -38,14 +38,13 @@ class MemoryRetrievalService: vec_ok = self._embedding.is_available() if self._embedding else False logger.info( "event=memory_retrieve_done user_id={} query_len={} top_k={} " - "chunks={} facts={} summaries={} timeline={} stories={} vector_ok={}", + "chunks={} facts={} summaries={} stories={} vector_ok={}", user_id, len((query or "").strip()), top_k, len(bd.get("relevant_chunks") or []), len(bd.get("relevant_facts") or []), len(bd.get("relevant_summaries") or []), - len(bd.get("timeline_hints") or []), len(bd.get("relevant_stories") or []), vec_ok, ) diff --git a/api/app/features/memory/retrieval_trace.py b/api/app/features/memory/retrieval_trace.py index f53751a..d7983ca 100644 --- a/api/app/features/memory/retrieval_trace.py +++ b/api/app/features/memory/retrieval_trace.py @@ -29,7 +29,6 @@ def chat_memory_retrieval_trace_from_bundle( "query_len": query_len, "chunk_ids": _capped_ids(bundle.get("relevant_chunks")), 
"fact_ids": _capped_ids(bundle.get("relevant_facts")), - "timeline_event_ids": _capped_ids(bundle.get("timeline_hints")), "summary_ids": _capped_ids(bundle.get("relevant_summaries")), "story_ids": _capped_ids(bundle.get("relevant_stories")), } diff --git a/api/app/features/memory/retriever.py b/api/app/features/memory/retriever.py index f398049..7102eb2 100644 --- a/api/app/features/memory/retriever.py +++ b/api/app/features/memory/retriever.py @@ -11,7 +11,7 @@ logger = get_logger(__name__) class HybridRetriever: - """向量 chunk 检索 + facts/timeline/summaries/stories。""" + """向量 chunk 检索 + facts/summaries/stories。""" def __init__( self, @@ -25,7 +25,7 @@ class HybridRetriever: async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> dict: """ Return evidence bundle: - {relevant_chunks, relevant_summaries, relevant_facts, timeline_hints, relevant_stories} + {relevant_chunks, relevant_summaries, relevant_facts, relevant_stories} """ if not query.strip(): return await retrieve_evidence_bundle_async( diff --git a/api/app/features/memory/schemas.py b/api/app/features/memory/schemas.py index 0ec894f..8d8e37e 100644 --- a/api/app/features/memory/schemas.py +++ b/api/app/features/memory/schemas.py @@ -9,5 +9,4 @@ class EvidenceBundle(BaseModel): relevant_chunks: list[dict] = [] relevant_summaries: list[dict] = [] relevant_facts: list[dict] = [] - timeline_hints: list[dict] = [] relevant_stories: list[dict] = [] diff --git a/api/app/features/memory/service.py b/api/app/features/memory/service.py index 7726007..b1e140c 100644 --- a/api/app/features/memory/service.py +++ b/api/app/features/memory/service.py @@ -8,9 +8,11 @@ Celery task 只能作为同步入口包装 async service,不再维护 sync mem from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger +from app.features.memory.embedding_service import MemoryEmbeddingService from app.features.memory.ingest_service import MemoryIngestService from app.features.memory.repo import ( create_curation_action, 
+ mark_facts_stale_for_excluded_chunk, set_chunk_excluded, set_memory_fact_status, ) @@ -78,7 +80,28 @@ class MemoryService: """Run post-ingest enrichment through the async memory path.""" from app.features.memory.enrichment import enrich_memory_after_ingest_async - await enrich_memory_after_ingest_async(self._db, user_id, source_id, llm=llm) + await enrich_memory_after_ingest_async( + self._db, + user_id, + source_id, + llm=llm, + raise_on_failure=True, + ) + + async def embed_source( + self, + user_id: str, + source_id: str, + *, + raise_on_failure: bool = False, + ) -> dict: + """Embed persisted memory chunks and update embedding status.""" + service = MemoryEmbeddingService(self._db, embedding_provider=self._embedding) + return await service.embed_source( + user_id, + source_id, + raise_on_failure=raise_on_failure, + ) async def compact_user(self, user_id: str, context: dict | None = None) -> dict: """Run near-duplicate compaction through the async memory path.""" @@ -92,13 +115,21 @@ class MemoryService: ok = await set_chunk_excluded(self._db, chunk_id, user_id, True) if not ok: return False + stale_count = await mark_facts_stale_for_excluded_chunk( + self._db, + user_id=user_id, + chunk_id=chunk_id, + ) await create_curation_action( self._db, user_id=user_id, action_type="exclude", target_type="chunk", target_id=chunk_id, - details={"reason": reason} if reason else None, + details={ + **({"reason": reason} if reason else {}), + "staled_fact_count": stale_count, + }, ) await self._db.commit() return True @@ -113,7 +144,7 @@ class MemoryService: action_type="restore", target_type="chunk", target_id=chunk_id, - details=None, + details={"fact_restore_policy": "requires_reenrichment"}, ) await self._db.commit() return True diff --git a/api/app/features/memory/summarizer.py b/api/app/features/memory/summarizer.py deleted file mode 100644 index d3f0bda..0000000 --- a/api/app/features/memory/summarizer.py +++ /dev/null @@ -1,82 +0,0 @@ -"""会话摘要与滚动摘要(async LLM + 
JSON)。""" - -from __future__ import annotations - -from typing import Any - -from app.core.langchain_llm import ainvoke_json_object -from app.core.logging import get_logger -from app.features.memory.llm_schemas import ( - RollingSummaryPayload, - SessionSummaryPayload, - parse_json_payload, -) - -logger = get_logger(__name__) - -_ROLLING_SUMMARY_MERGE_RULES_ZH = ( - "若新材料与已有摘要在同一人物或事件上存在明显事实冲突(如阵亡与在世、牺牲与退休、军衔或驻地去向矛盾)," - "以新材料为准,删除或改写旧摘要中的矛盾句;不得把两处矛盾信息拼接成一句。" - "不得将两则无因果关联的信息强行合成因果关系。" -) - - -def _max_input_chars() -> int: - from app.core.config import settings - - return settings.memory_enrichment_max_chars - - -async def generate_session_summary_async(llm: Any, chunk_texts: list[str]) -> str: - if not llm: - return "" - lim = _max_input_chars() - combined = "\n\n".join(t for t in chunk_texts if t).strip()[:lim] - if not combined: - return "" - prompt = ( - "用 2~8 句中文概括下列口述/对话要点,不编造、不评价。只输出 JSON:" - '{"summary":"..."}\n\n文本:\n' - f"{combined}" - ) - try: - raw = await ainvoke_json_object( - llm, prompt, max_tokens=2048, agent="memory.session_summary_async" - ) - parsed = parse_json_payload(raw, SessionSummaryPayload) - if parsed is None: - return "" - return str(parsed.summary or "").strip() - except (TypeError, ValueError) as e: - logger.warning("generate_session_summary_async 失败: {}", e) - return "" - - -async def generate_rolling_summary_async( - llm: Any, existing_summary: str | None, new_chunk_texts: list[str] -) -> str: - if not llm: - return (existing_summary or "").strip() - lim = _max_input_chars() - new_t = "\n\n".join(t for t in new_chunk_texts if t).strip()[:lim] - if not new_t and not (existing_summary or "").strip(): - return "" - ex = (existing_summary or "").strip()[:lim] - prompt = ( - "将「已有滚动摘要」与「新材料」合并为更新后的滚动摘要(中文,段落)。" - "保留人物与时间线索;不编造。\n" - f"{_ROLLING_SUMMARY_MERGE_RULES_ZH}\n" - '只输出 JSON:{"rolling_summary":"..."}\n\n' - f"【已有摘要】\n{ex}\n\n【新材料】\n{new_t}" - ) - try: - raw = await ainvoke_json_object( - llm, prompt, max_tokens=3072, 
agent="memory.rolling_summary_async" - ) - parsed = parse_json_payload(raw, RollingSummaryPayload) - if parsed is None: - return (existing_summary or "").strip() - return str(parsed.rolling_summary or "").strip() - except (TypeError, ValueError) as e: - logger.warning("generate_rolling_summary_async 失败: {}", e) - return (existing_summary or "").strip() diff --git a/api/app/features/memory/timeline.py b/api/app/features/memory/timeline.py deleted file mode 100644 index 7d4e29a..0000000 --- a/api/app/features/memory/timeline.py +++ /dev/null @@ -1,52 +0,0 @@ -"""由已抽取事实生成时间线事件(async LLM + JSON)。""" - -from __future__ import annotations - -import json -from typing import Any - -from app.core.langchain_llm import ainvoke_json_object -from app.core.llm_gateway import LlmGateway, LlmUseCase -from app.core.logging import get_logger -from app.features.memory.llm_schemas import ( - TimelineEventsPayload, - parse_json_payload, - timeline_payload_to_dicts, -) - -logger = get_logger(__name__) - -MAX_FACTS_JSON = 20000 - - -async def build_timeline_events_from_facts_async( - llm: Any, facts: list[dict] -) -> list[dict]: - if not llm or not facts: - return [] - payload = json.dumps(facts, ensure_ascii=False)[:MAX_FACTS_JSON] - prompt = ( - "根据下列事实(含 id)生成时间线事件。\n" - "每条含 event_year、event_date、title、description、source_fact_ids(来自输入 id)。\n" - '只输出 JSON:{"events":[...]}。\n\n' - f"{payload}" - ) - try: - raw = await ainvoke_json_object( - llm, prompt, max_tokens=4096, agent="memory.timeline_events_async" - ) - parsed = parse_json_payload(raw, TimelineEventsPayload) - if parsed is None: - return [] - return timeline_payload_to_dicts(parsed) - except (TypeError, ValueError) as e: - logger.warning("build_timeline_events_from_facts_async 失败: {}", e) - return [] - - -async def build_timeline_events(facts: list[dict]) -> list[dict]: - """兼容旧接口。""" - llm = LlmGateway().langchain_llm_for( - LlmUseCase("memory.timeline_events.compat", fast=True) - ) - return await 
build_timeline_events_from_facts_async(llm, facts) diff --git a/api/app/features/user/router.py b/api/app/features/user/router.py index 2532c11..26a1eaa 100644 --- a/api/app/features/user/router.py +++ b/api/app/features/user/router.py @@ -66,6 +66,11 @@ async def update_user_profile( current_user: User = Depends(get_current_user), service: UserService = Depends(get_user_service), ): + logger.info( + "更新用户档案 user_id={} fields={}", + current_user.id, + sorted(body.model_fields_set), + ) return await service.update_profile(current_user.id, body) diff --git a/api/app/features/user/service.py b/api/app/features/user/service.py index a1c350f..2681df2 100644 --- a/api/app/features/user/service.py +++ b/api/app/features/user/service.py @@ -46,14 +46,9 @@ class UserService: user = await repo.get_user_by_id(user_id, self._db) if not user: raise ValueError("用户不存在") - if body.birth_year is not None: - user.birth_year = body.birth_year - if body.birth_place is not None: - user.birth_place = body.birth_place - if body.grew_up_place is not None: - user.grew_up_place = body.grew_up_place - if body.occupation is not None: - user.occupation = body.occupation + for field in ("birth_year", "birth_place", "grew_up_place", "occupation"): + if field in body.model_fields_set: + setattr(user, field, getattr(body, field)) await self._db.commit() await self._db.refresh(user) return _user_to_profile(user) diff --git a/api/app/tasks/__init__.py b/api/app/tasks/__init__.py index efeae56..324da23 100644 --- a/api/app/tasks/__init__.py +++ b/api/app/tasks/__init__.py @@ -7,7 +7,6 @@ from .chapter_cover_tasks import generate_chapter_cover from .memoir_tasks import ( process_memoir_phase1, process_memoir_phase2, - process_memoir_segments, ) from .memory_compaction_tasks import memory_compaction_run from .story_image_tasks import generate_story_image @@ -16,7 +15,6 @@ __all__ = [ "celery_app", "process_memoir_phase1", "process_memoir_phase2", - "process_memoir_segments", "generate_chapter_cover", 
"generate_story_image", "memory_compaction_run", diff --git a/api/app/tasks/celery_app.py b/api/app/tasks/celery_app.py index 20cdff8..a8a8ad3 100644 --- a/api/app/tasks/celery_app.py +++ b/api/app/tasks/celery_app.py @@ -72,6 +72,9 @@ celery_app.conf.update( task_acks_late=True, # 任务完成后再确认 task_reject_on_worker_lost=True, # worker 丢失时拒绝任务 task_routes={ + "app.tasks.memory_enrichment_tasks.embed_memory_source": { + "queue": settings.celery_memory_enrichment_queue, + }, "app.tasks.memory_enrichment_tasks.enrich_memory_source": { "queue": settings.celery_memory_enrichment_queue, }, @@ -79,6 +82,10 @@ celery_app.conf.update( ) celery_app.conf.task_annotations = { + "app.tasks.memory_enrichment_tasks.embed_memory_source": { + "soft_time_limit": 660, + "time_limit": 960, + }, "app.tasks.memory_enrichment_tasks.enrich_memory_source": { "soft_time_limit": 660, "time_limit": 960, diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index b2434bd..440d08c 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -6,7 +6,7 @@ import asyncio import json import time import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Dict, List, Set import redis @@ -336,6 +336,133 @@ def _phase2_immediate_task_id(user_id: str, chapter_category: str) -> str: return f"phase2-immediate-{user_id}-{chapter_category}" +def _wake_deferred_segments_for_category( + db: Session, + user_id: str, + chapter_category: str, +) -> int: + """清空该用户某 chapter_category 下旧的 defer 元数据,让其与新素材一起重判。 + + 返回被唤醒的 segment 数量,仅用于日志。 + """ + user_convs = select(Conversation.id).where( + Conversation.user_id == user_id, + Conversation.deleted_at.is_(None), + ) + stmt = select(Segment).where( + Segment.conversation_id.in_(user_convs), + Segment.topic_category == chapter_category, + Segment.narrated.is_(False), + Segment.skip_narrative.is_(False), + Segment.narrative_deferred_until.isnot(None), + ) + rows = 
list(db.execute(stmt).scalars().all()) + if not rows: + return 0 + for seg in rows: + seg.narrative_deferred_until = None + seg.narrative_defer_count = 0 + seg.narrative_defer_reason = None + return len(rows) + + +def _persist_phase2_route_defer( + db: Session, + *, + user_id: str, + chapter_category: str, + task_id: str, + memoir_correlation_id: str | None, + defer_segment_ids: list[str], + defer_reason: str, + phase2_started: float, + pipeline_elapsed: float, + lock_elapsed: float, +) -> dict: + """把本批 segment 标记为延迟态,并按需再排一次 Phase2 timeout。 + + 返回 Celery 任务的 result dict(``status=deferred``)。 + """ + now_ts = datetime.now(timezone.utc) + max_attempts = int(settings.memoir_route_defer_max_attempts) + defer_seconds = float(settings.memoir_route_defer_seconds) + deferred_until_ts = now_ts + timedelta(seconds=max(defer_seconds, 1.0)) + + rows: list[Segment] = [] + if defer_segment_ids: + stmt = select(Segment).where(Segment.id.in_(list(defer_segment_ids))) + rows = list(db.execute(stmt).scalars().all()) + + saturated_segments = 0 + new_max_attempts_reached = False + for seg in rows: + prev_count = int(seg.narrative_defer_count or 0) + seg.narrative_defer_count = prev_count + 1 + seg.narrative_defer_reason = defer_reason + seg.narrative_last_attempt_at = now_ts + if seg.narrative_defer_count >= max_attempts: + seg.narrative_deferred_until = None + saturated_segments += 1 + new_max_attempts_reached = True + else: + seg.narrative_deferred_until = deferred_until_ts + + db.commit() + + next_task_id: str | None = None + if rows and not new_max_attempts_reached: + next_task_id = _schedule_phase2_timeout( + user_id, chapter_category, memoir_correlation_id + ) + + phase2_elapsed = time.perf_counter() - phase2_started + duration_ms = phase2_elapsed * 1000 + logger.info( + "event=memoir_phase2_route_deferred user_id={} task_id={} chapter_category={} " + "segment_count={} saturated_count={} reason={} memoir_correlation_id={} " + "lock_seconds={:.3f} pipeline_seconds={:.3f} " + 
"phase2_total_seconds={:.3f} duration_ms={:.1f} next_task_id={} " + "msg=Phase2 路由低置信,本批 segment 延迟", + user_id, + task_id, + chapter_category, + len(rows), + saturated_segments, + defer_reason, + memoir_correlation_id or "", + lock_elapsed, + pipeline_elapsed, + phase2_elapsed, + duration_ms, + next_task_id or "", + ) + merge_pipeline_run( + memoir_correlation_id, + { + "phase2": [ + { + "chapter_category": chapter_category, + "task_id": str(task_id), + "status": "deferred", + "detail": { + "segments": len(rows), + "reason": defer_reason, + "saturated_count": saturated_segments, + "next_task_id": next_task_id, + }, + } + ], + }, + ) + return { + "status": "deferred", + "chapter_category": chapter_category, + "segments": len(rows), + "reason": defer_reason, + "saturated_count": saturated_segments, + } + + def _schedule_phase2_timeout( user_id: str, chapter_category: str, memoir_correlation_id: str | None = None ) -> str | None: @@ -492,6 +619,7 @@ def process_memoir_phase2( Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) + now_utc = datetime.now(timezone.utc) stmt = ( select(Segment) .where( @@ -499,6 +627,10 @@ def process_memoir_phase2( Segment.topic_category == chapter_category, Segment.narrated.is_(False), Segment.skip_narrative.is_(False), + ( + Segment.narrative_deferred_until.is_(None) + | (Segment.narrative_deferred_until <= now_utc) + ), ) .order_by(Segment.created_at) ) @@ -603,11 +735,10 @@ def process_memoir_phase2( "relevant_chunks": [], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } pipeline_t0 = time.perf_counter() - chapter, needs_cover, disp = run_story_pipeline_for_category_batch( + pipeline_result = run_story_pipeline_for_category_batch( db, user_id=user_id, chapter_category=chapter_category, @@ -623,7 +754,24 @@ def process_memoir_phase2( memory_evidence=memory_evidence, ) pipeline_elapsed = time.perf_counter() - pipeline_t0 - story_dispatch_ids |= disp + + if 
pipeline_result.deferred: + deferred_response = _persist_phase2_route_defer( + db, + user_id=user_id, + chapter_category=chapter_category, + task_id=str(task_id), + memoir_correlation_id=cid, + defer_segment_ids=pipeline_result.defer_segment_ids, + defer_reason=pipeline_result.defer_reason or "unknown", + phase2_started=phase2_t0, + pipeline_elapsed=pipeline_elapsed, + lock_elapsed=lock_elapsed, + ) + return deferred_response + + chapter = pipeline_result.chapter + story_dispatch_ids |= pipeline_result.dispatch_ids db.flush() if chapter is None: logger.error( @@ -949,6 +1097,7 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): categories_for_phase2: Set[str] = set() phase2_immediate: list[str] = [] phase2_timeout: list[str] = [] + woke_up_by_category: dict[str, int] = {} for chapter_category, cat_segments in prepared.category_to_segments.items(): batch_non_skip = [ s @@ -957,6 +1106,11 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): ] if not batch_non_skip: continue + woke = _wake_deferred_segments_for_category( + db, user_id, chapter_category + ) + if woke: + woke_up_by_category[chapter_category] = woke max_chars = max( len((s.user_input_text or "").strip()) for s in batch_non_skip ) @@ -966,6 +1120,14 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): else: phase2_timeout.append(chapter_category) + if woke_up_by_category: + logger.info( + "event=memoir_phase1_wake_deferred user_id={} categories={} " + "msg=Phase1 新素材唤醒同类目延迟 segment", + user_id, + woke_up_by_category, + ) + db.commit() merge_pipeline_run( @@ -1081,11 +1243,6 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) raise self.retry(exc=e) from e - -# 兼容旧 Celery/文档入口名 -process_memoir_segments = process_memoir_phase1 - - @shared_task(bind=True, max_retries=3, default_retry_delay=30) def generate_chapter_content(self, user_id: str, stage: 
str, new_content: str): """ diff --git a/api/app/tasks/memory_enrichment_tasks.py b/api/app/tasks/memory_enrichment_tasks.py index bcf9542..54434ef 100644 --- a/api/app/tasks/memory_enrichment_tasks.py +++ b/api/app/tasks/memory_enrichment_tasks.py @@ -1,6 +1,5 @@ """ -Memory enrichment Celery task — runs asynchronously after ingest to generate -summaries, facts, and timeline events without blocking ingest or memoir pipeline. +Memory pipeline Celery tasks — retry embedding and enrichment after durable ingest. Tasks are routed to ``settings.celery_memory_enrichment_queue`` (default ``memory_idle``); run workers with ``-Q celery,memory_idle`` or a dedicated low-priority worker for that queue. @@ -8,11 +7,13 @@ run workers with ``-Q celery,memory_idle`` or a dedicated low-priority worker fo import asyncio import time +from typing import Any, cast from celery import shared_task from app.core.config import settings from app.core.db import AsyncSessionLocal +from app.core.dependencies import get_embedding_provider from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_fanout_item from app.features.memory.service import MemoryService @@ -30,6 +31,65 @@ async def _enrich_memory_source_async( await db.commit() +async def _embed_memory_source_async( + user_id: str, + source_id: str, +) -> dict: + async with AsyncSessionLocal() as db: + service = MemoryService(db, embedding_provider=get_embedding_provider()) + result = await service.embed_source( + user_id, + source_id, + raise_on_failure=True, + ) + await db.commit() + return result + + +def schedule_memory_embedding( + user_id: str, + source_id: str, + *, + memoir_correlation_id: str | None = None, +) -> str | None: + """Enqueue embedding retry for a persisted memory source.""" + uid = (user_id or "").strip() + sid = (source_id or "").strip() + if not uid or not sid: + return None + q = (settings.celery_memory_enrichment_queue or "").strip() or "memory_idle" + try: + task = cast(Any, 
embed_memory_source) + ar = task.apply_async( + args=[uid, sid], + kwargs={"memoir_correlation_id": memoir_correlation_id}, + queue=q, + ) + emb_id = getattr(ar, "id", None) + if not emb_id: + return None + cid = (memoir_correlation_id or "").strip() + if cid: + merge_fanout_item( + cid, + list_name="memory_embedding", + id_field="source_id", + item_id=sid, + task_id=str(emb_id), + status="enqueued", + ) + return str(emb_id) + except Exception as e: + logger.warning( + "event=memory_embedding_schedule_failed user_id={} source_id={} exc={} exc_type={}", + uid, + sid, + e, + type(e).__name__, + ) + return None + + def schedule_memory_enrichment( user_id: str, source_id: str, @@ -50,7 +110,8 @@ def schedule_memory_enrichment( return None q = (settings.celery_memory_enrichment_queue or "").strip() or "memory_idle" try: - ar = enrich_memory_source.apply_async( + task = cast(Any, enrich_memory_source) + ar = task.apply_async( args=[uid, sid], kwargs={"memoir_correlation_id": memoir_correlation_id}, queue=q, @@ -80,6 +141,74 @@ def schedule_memory_enrichment( return None +@shared_task(bind=True, max_retries=3, default_retry_delay=30) +def embed_memory_source( + self, + user_id: str, + source_id: str, + memoir_correlation_id: str | None = None, +): + """Post-ingest embedding retry for persisted chunks.""" + tid = str(self.request.id) + t0 = time.perf_counter() + logger.info( + "event=memory_embedding_start user_id={} source_id={} task_id={} msg=开始记忆向量化", + user_id, + source_id, + tid, + ) + merge_fanout_item( + memoir_correlation_id, + list_name="memory_embedding", + id_field="source_id", + item_id=source_id, + task_id=tid, + status="running", + ) + try: + result = asyncio.run(_embed_memory_source_async(user_id, source_id)) + ms = (time.perf_counter() - t0) * 1000 + logger.info( + "event=memory_embedding_done user_id={} source_id={} duration_ms={:.1f} status={} vectors_written={} msg=记忆向量化完成", + user_id, + source_id, + ms, + result.get("status"), + 
result.get("vectors_written", 0), + ) + merge_fanout_item( + memoir_correlation_id, + list_name="memory_embedding", + id_field="source_id", + item_id=source_id, + task_id=tid, + status="success", + extra=result, + ) + return {"source_id": source_id, **result} + except Exception as e: + ms = (time.perf_counter() - t0) * 1000 + logger.warning( + "event=memory_embedding_failed user_id={} source_id={} duration_ms={:.1f} " + "exc={} exc_type={} msg=记忆向量化失败", + user_id, + source_id, + ms, + e, + type(e).__name__, + ) + merge_fanout_item( + memoir_correlation_id, + list_name="memory_embedding", + id_field="source_id", + item_id=source_id, + task_id=tid, + status="failure", + extra={"error": str(e)}, + ) + raise self.retry(exc=e) from e + + @shared_task(bind=True, max_retries=2, default_retry_delay=30) def enrich_memory_source( self, diff --git a/api/development.sh b/api/development.sh index 6bd1e8a..e99bee7 100755 --- a/api/development.sh +++ b/api/development.sh @@ -168,6 +168,7 @@ start_infra() { cd "${ROOT_DIR}" docker compose -f docker-compose.dev.yml up -d INFRA_STARTED=1 + print_ok "PostgreSQL 127.0.0.1:48291,Redis 127.0.0.1:48307(见 docker-compose.dev.yml / .env.example)" print_ok "基础设施已就绪" } @@ -296,7 +297,7 @@ print_alembic_failure_hint() { log_output="$(sed -n '1,200p' "${log_file}")" if [[ "${log_output}" == *'could not translate host name "postgres"'* ]] || [[ "${log_output}" == *"Name or service not known"* ]]; then - print_warn "看起来 DATABASE_URL 指向了容器内主机名;在宿主机运行时请改用 localhost:5432" + print_warn "看起来 DATABASE_URL 指向了容器内主机名;在宿主机运行时请改用 localhost:48291(见 docker-compose.dev.yml)" elif [[ "${log_output}" == *"Connection refused"* ]] || [[ "${log_output}" == *"could not connect to server"* ]]; then print_warn "PostgreSQL 连接被拒绝;请确认容器已启动且 DATABASE_URL 与 docker-compose.dev.yml 暴露端口一致" elif [[ "${log_output}" == *"password authentication failed"* ]]; then diff --git a/api/docker-compose.dev.yml b/api/docker-compose.dev.yml index b449830..c474fcc 100644 --- 
a/api/docker-compose.dev.yml +++ b/api/docker-compose.dev.yml @@ -1,5 +1,9 @@ # 开发环境 Docker Compose # 使用方法: docker compose -f docker-compose.dev.yml up -d +# +# 宿主端口为项目约定的固定高位端口(避免与本机常用 5432/6379 冲突),与本仓库 .env.example 对齐: +# PostgreSQL 127.0.0.1:48291 → 容器 5432 +# Redis 127.0.0.1:48307 → 容器 6379 services: # PostgreSQL 数据库(pg17 + pgvector,memory 模块需要 vector 类型) diff --git a/api/docker-compose.yml b/api/docker-compose.yml index b3e6368..0e13a42 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -4,7 +4,7 @@ services: image: m.daocloud.io/docker.io/pgvector/pgvector:pg17 container_name: life-echo-postgres ports: - - "127.0.0.1:5432:5432" # 仅绑定 localhost,通过 SSH 隧道访问 + - "127.0.0.1:${POSTGRES_HOST_PORT:-5432}:5432" # 仅绑定 localhost,通过 SSH 隧道访问 environment: POSTGRES_USER: ${POSTGRES_USER:-postgres} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres} @@ -56,10 +56,10 @@ services: dockerfile: Dockerfile image: life-echo-api:latest container_name: life-echo-api-prod - # 独立 Caddy(宿主机或其它 compose)经 HTTPS 反代;仅绑定本机回环,避免与机上其它项目端口直接对公网。 - # 若与 Cosmetic 等共用主机且 8000 已被占用,在 .env 中设置 LIFE_ECHO_API_HOST_PORT=其它端口并在 Caddyfile 中一致。 + # 默认仅绑定本机回环,交给宿主机 Caddy/反代;staging 如需 IP:port 直连,可在 .env 设置 LIFE_ECHO_API_HOST_BIND=0.0.0.0。 + # 若与 Cosmetic 等共用主机且 8000 已被占用,在 .env 中设置 LIFE_ECHO_API_HOST_PORT=其它端口并在 Caddyfile / app env 中一致。 ports: - - "127.0.0.1:${LIFE_ECHO_API_HOST_PORT:-8000}:8000" + - "${LIFE_ECHO_API_HOST_BIND:-127.0.0.1}:${LIFE_ECHO_API_HOST_PORT:-8000}:8000" env_file: - .env environment: diff --git a/api/docs/ai-touchpoints.md b/api/docs/ai-touchpoints.md index 3cecf96..3b1bbd5 100644 --- a/api/docs/ai-touchpoints.md +++ b/api/docs/ai-touchpoints.md @@ -56,18 +56,17 @@ Regenerate: `uv run python api/scripts/ai_touchpoints_scan.py --markdown api/doc | `api/app/features/memoir/story_pipeline_sync.py` | `agents_layer`, `embedding` | | `api/app/features/memory/curation.py` | `memory_ai` | | `api/app/features/memory/deps.py` | `embedding`, `memory_ai` | +| 
`api/app/features/memory/embedding_scheduler.py` | `memory_ai` | +| `api/app/features/memory/embedding_service.py` | `embedding`, `memory_ai` | | `api/app/features/memory/enrichment.py` | `json_llm_helpers`, `langchain`, `llm_provider`, `memory_ai` | | `api/app/features/memory/evidence.py` | `embedding`, `memory_ai`, `ports_ai` | | `api/app/features/memory/evidence_format.py` | `memory_ai` | -| `api/app/features/memory/extractor.py` | `json_llm_helpers`, `langchain`, `llm_provider` | | `api/app/features/memory/llm_schemas.py` | `json_llm_helpers` | | `api/app/features/memory/repo.py` | `embedding`, `memory_ai`, `ports_ai` | | `api/app/features/memory/retriever.py` | `embedding`, `memory_ai`, `ports_ai` | | `api/app/features/memory/router.py` | `memory_ai` | | `api/app/features/memory/schemas.py` | `memory_ai` | | `api/app/features/memory/service.py` | `embedding`, `memory_ai`, `ports_ai` | -| `api/app/features/memory/summarizer.py` | `json_llm_helpers`, `langchain` | -| `api/app/features/memory/timeline.py` | `json_llm_helpers`, `langchain`, `llm_provider` | | `api/app/ports/embedding.py` | `embedding` | | `api/app/ports/llm.py` | `ports_ai` | | `api/app/tasks/chapter_cover_tasks.py` | `agents_layer` | diff --git a/api/docs/memory-retrieval.md b/api/docs/memory-retrieval.md index 51a3a2f..7528ab5 100644 --- a/api/docs/memory-retrieval.md +++ b/api/docs/memory-retrieval.md @@ -4,15 +4,16 @@ Memory 运行链路只有一个入口:`MemoryService`。 | 能力 | 入口 | 行为 | | --- | --- | --- | -| ingest | `MemoryService.ingest_transcript` / `ingest_transcripts_batch` | 写入 `memory_sources`、`memory_chunks`、embedding;commit 后投递 enrichment | -| retrieve | `MemoryService.retrieve` | 非空 query 做向量 chunk 检索,并合并 query 命中的 facts / timeline / session summaries / stories | +| ingest | `MemoryService.ingest_transcript` / `ingest_transcripts_batch` | 先持久化 `memory_sources`、`memory_chunks`;随后写 embedding 状态并投递 enrichment | +| embed | `MemoryService.embed_source` | 对已持久化 chunks 生成向量;失败记录状态并由 Celery 重试 | +| retrieve 
| `MemoryService.retrieve` | 非空 query 做向量 chunk 检索,并合并 query 命中的 facts / session summaries / stories | | enrichment | `MemoryService.enrich_source` | 单次 LLM 生成 session summary 与 confirmed facts | | compaction | `MemoryService.compact_user` | 近重复 chunk 软排除并 stale 相关 facts | ## 检索语义 - 空 query 固定返回空 evidence bundle。 -- facts / timeline / summaries 只按 query 命中返回;不回退最近事实、最近时间线或 rolling summary。 +- facts / summaries 只按 query 命中返回;不回退最近事实或 rolling summary。 - `MemorySummary.summary_type="session"` 可进入 evidence;rolling summary 不参与 prompt evidence。 - Celery task 只是同步入口包装 async service,不再维护 sync memory 业务链路。 diff --git a/api/docs/本地开发环境配置.md b/api/docs/本地开发环境配置.md index 4655865..2082b89 100644 --- a/api/docs/本地开发环境配置.md +++ b/api/docs/本地开发环境配置.md @@ -24,15 +24,20 @@ ## 快速开始 -### 1. 启动 Redis +### 1. 启动 PostgreSQL / Redis -使用 Docker Compose 启动 Redis: +使用开发用 Docker Compose 一键启动数据库与缓存: ```bash cd api docker compose -f docker-compose.dev.yml up -d ``` +开发 compose 使用 **固定的** 本机映射(与 `api/.env.example` 一致,避免与本机默认 5432 / 6379 抢占): + +- PostgreSQL:`127.0.0.1:48291` → 容器内 `5432` +- Redis:`127.0.0.1:48307` → 容器内 `6379` + 验证 Redis 是否运行: ```bash @@ -61,12 +66,12 @@ DEEPSEEK_BASE_URL=https://api.deepseek.com # LLM_MODEL=gpt-4 # LLM_BASE_URL=https://api.openai.com -# Redis 配置 -REDIS_URL=redis://localhost:6379/0 +# Redis 配置(宿主 48307,见 docker-compose.dev.yml) +REDIS_URL=redis://localhost:48307/0 REDIS_SESSION_TTL=86400 # 会话过期时间(秒),默认 24 小时 -# 数据库配置(PostgreSQL,与线上一致) -DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo +# 数据库配置(宿主 48291,见 docker-compose.dev.yml) +DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo # JWT 配置 SECRET_KEY=your-secret-key-change-in-production @@ -114,7 +119,7 @@ celery -A tasks.celery_app worker --loglevel=info --concurrency=2 - 对话的实时响应通过异步 LLM 调用生成 - 会话历史存储在 Redis 中 -### Redis (端口 6379) +### Redis(容器内 6379 → 宿主 48307,见 docker-compose.dev.yml) - 存储对话会话历史(支持多实例部署) - 作为 Celery 的消息队列 @@ -169,12 +174,12 @@ docker compose up -d --scale 
celery-worker=3 ### Redis 连接失败 ``` -Redis 连接失败: Error connecting to redis://localhost:6379/0 +Redis 连接失败: Error connecting to redis://localhost:48307/0 ``` **解决方法**: 1. 确认 Redis 容器正在运行:`docker ps | grep redis` -2. 检查 `REDIS_URL` 环境变量是否正确 +2. 检查 `REDIS_URL` 是否为 `redis://localhost:48307/0`(或与 `docker-compose.dev.yml` 中映射一致) 3. 如果在 Docker 内运行 API,使用 `redis://redis:6379/0` ### Celery 任务不执行 @@ -251,9 +256,9 @@ asyncio.run(test()) ### 手动触发 Celery 任务 ```python -from app.tasks.memoir_tasks import process_memoir_segments +from app.tasks.memoir_tasks import process_memoir_phase1 # 同步调用(测试) -result = process_memoir_segments.delay("user_id", ["segment_id_1", "segment_id_2"]) +result = process_memoir_phase1.delay("user_id", ["segment_id_1", "segment_id_2"]) print(result.get(timeout=60)) ``` diff --git a/api/tests/test_avatar_preset_http.py b/api/tests/test_avatar_preset_http.py new file mode 100644 index 0000000..e9d2f67 --- /dev/null +++ b/api/tests/test_avatar_preset_http.py @@ -0,0 +1,170 @@ +"""预设头像 HTTP 契约。""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.core.dependencies import get_current_user +from app.features.auth.deps import get_auth_service +from app.features.auth.router import router as auth_router +from app.features.auth.service import AuthService +from app.features.user.models import User + + +def _mock_current_user() -> User: + u = MagicMock(spec=User) + u.id = str(uuid.uuid4()) + u.phone = "13800000000" + u.email = None + u.nickname = "测试用户" + u.avatar_url = None + u.subscription_type = "free" + u.created_at = datetime.now(timezone.utc) + return u + + +@pytest.fixture +def preset_auth_app() -> FastAPI: + app = FastAPI() + app.include_router(auth_router) + + fixed_user = _mock_current_user() + + async def _fake_update_avatar(uid: str, url: 
str): + fixed_user.avatar_url = url + return fixed_user + + mock_service = MagicMock(spec=AuthService) + mock_service.update_avatar_url = AsyncMock(side_effect=_fake_update_avatar) + + app.dependency_overrides[get_auth_service] = lambda: mock_service + app.dependency_overrides[get_current_user] = lambda: fixed_user + + app.state._mock_auth_service = mock_service + app.state._fixed_user = fixed_user + return app + + +@pytest.mark.asyncio +async def test_list_avatar_presets(preset_auth_app: FastAPI) -> None: + transport = ASGITransport(app=preset_auth_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatar-presets") + assert r.status_code == 200 + items = r.json() + assert len(items) == 8 + assert items[0]["id"] == "01" + assert items[0]["url"] == "/api/auth/avatar-presets/01.png" + + +@pytest.mark.asyncio +async def test_get_avatar_preset_ok(preset_auth_app: FastAPI) -> None: + transport = ASGITransport(app=preset_auth_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatar-presets/01.png") + assert r.status_code == 200 + assert r.headers.get("content-type", "").startswith("image/png") + + +@pytest.mark.asyncio +async def test_get_avatar_preset_unknown(preset_auth_app: FastAPI) -> None: + transport = ASGITransport(app=preset_auth_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatar-presets/99.png") + assert r.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_avatar_preset_path_traversal(preset_auth_app: FastAPI) -> None: + transport = ASGITransport(app=preset_auth_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatar-presets/../secrets.env") + assert r.status_code == 404 + + +@pytest.mark.asyncio +async def test_set_avatar_preset_ok(preset_auth_app: FastAPI) -> None: + uid = 
preset_auth_app.state._fixed_user.id + transport = ASGITransport(app=preset_auth_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.put( + "/api/auth/me/avatar/preset", + json={"preset_id": "02"}, + headers={"Authorization": "Bearer x"}, + ) + assert r.status_code == 200 + body = r.json() + assert body["avatar_url"] == "/api/auth/avatar-presets/02.png" + svc: MagicMock = preset_auth_app.state._mock_auth_service + svc.update_avatar_url.assert_awaited_once_with( + uid, "/api/auth/avatar-presets/02.png" + ) + + +@pytest.mark.asyncio +async def test_set_avatar_preset_invalid(preset_auth_app: FastAPI) -> None: + transport = ASGITransport(app=preset_auth_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.put( + "/api/auth/me/avatar/preset", + json={"preset_id": "99"}, + headers={"Authorization": "Bearer x"}, + ) + assert r.status_code == 400 + + +@pytest.mark.asyncio +async def test_get_uploaded_avatar_rejects_traversal() -> None: + app = FastAPI() + app.include_router(auth_router) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatars/../../../etc/passwd") + assert r.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_uploaded_avatar_ok_with_safe_file( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import app.features.auth.router as auth_router_mod + + avatar_dir = tmp_path / "avatars" + avatar_dir.mkdir() + (avatar_dir / "abc-def-123.jpg").write_bytes(b"x") + + monkeypatch.setattr(auth_router_mod, "AVATAR_DIR", avatar_dir) + + app = FastAPI() + app.include_router(auth_router) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatars/abc-def-123.jpg") + assert r.status_code == 200 + + +@pytest.mark.asyncio +async def 
test_get_uploaded_avatar_ok_when_stem_has_underscore( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import app.features.auth.router as auth_router_mod + + avatar_dir = tmp_path / "avatars" + avatar_dir.mkdir() + (avatar_dir / "user_abc_01.jpg").write_bytes(b"x") + + monkeypatch.setattr(auth_router_mod, "AVATAR_DIR", avatar_dir) + + app = FastAPI() + app.include_router(auth_router) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + r = await ac.get("/api/auth/avatars/user_abc_01.jpg") + assert r.status_code == 200 diff --git a/api/tests/test_dialogue_lineage_memory_ingest.py b/api/tests/test_dialogue_lineage_memory_ingest.py index 144ad36..c303c0a 100644 --- a/api/tests/test_dialogue_lineage_memory_ingest.py +++ b/api/tests/test_dialogue_lineage_memory_ingest.py @@ -27,6 +27,19 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None: captured["scheduled"] = request return "task-1" + class FakeEmbeddingScheduler: + def schedule(self, request): + captured["embedding_scheduled"] = request + return "embedding-task-1" + + class FakeEmbeddingService: + def __init__(self, *_args, **_kwargs) -> None: + pass + + async def embed_source(self, user_id: str, source_id: str) -> dict: + captured["embedded"] = (user_id, source_id) + return {"status": "success", "vectors_written": 1} + async def fake_create_source(session, **kwargs): captured.update(kwargs) return SimpleNamespace(id="src-1") @@ -42,6 +55,10 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None: "app.features.memory.ingest_service.create_chunk", fake_create_chunk, ) + monkeypatch.setattr( + "app.features.memory.ingest_service.MemoryEmbeddingService", + FakeEmbeddingService, + ) monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", False) lineage = { @@ -57,6 +74,7 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None: service = MemoryIngestService( fake_session, # type: 
ignore[arg-type] embedding_provider=None, + embedding_scheduler=FakeEmbeddingScheduler(), # type: ignore[arg-type] enrichment_scheduler=FakeScheduler(), # type: ignore[arg-type] ) sid = await service.ingest_transcript( diff --git a/api/tests/test_image_prompt_policy.py b/api/tests/test_image_prompt_policy.py index e337c11..005e103 100644 --- a/api/tests/test_image_prompt_policy.py +++ b/api/tests/test_image_prompt_policy.py @@ -121,3 +121,44 @@ def test_cover_fallback_disabled_requires_excerpt(monkeypatch): chapter_category="family", context_excerpt="", ) + + +def test_image_prompt_orchestrator_provider_failure_uses_fallback(monkeypatch): + from app.agents.image_prompt.orchestrator import get_image_prompt_orchestrator + + class BoomGateway: + def langchain_llm_for(self, *_a, **_kw): # noqa: ANN001 + raise RuntimeError("provider missing") + + monkeypatch.setattr( + "app.agents.image_prompt.orchestrator.settings.image_prompt_fallback_disabled", + False, + ) + monkeypatch.setattr("app.core.llm_gateway.LlmGateway", lambda: BoomGateway()) + + orch = get_image_prompt_orchestrator() + out = orch.build_cover_prompt( + chapter_title="T", + chapter_category="family", + context_excerpt="mountain lake", + ) + assert "mountain lake" in out["prompt"].lower() + + +def test_image_prompt_orchestrator_provider_failure_raises_when_disabled( + monkeypatch, +): + from app.agents.image_prompt.orchestrator import get_image_prompt_orchestrator + + class BoomGateway: + def langchain_llm_for(self, *_a, **_kw): # noqa: ANN001 + raise RuntimeError("provider missing") + + monkeypatch.setattr( + "app.agents.image_prompt.orchestrator.settings.image_prompt_fallback_disabled", + True, + ) + monkeypatch.setattr("app.core.llm_gateway.LlmGateway", lambda: BoomGateway()) + + with pytest.raises(RuntimeError, match="provider missing"): + get_image_prompt_orchestrator() diff --git a/api/tests/test_json_and_memory_utils.py b/api/tests/test_json_and_memory_utils.py index 4643a0f..d29782a 100644 --- 
a/api/tests/test_json_and_memory_utils.py +++ b/api/tests/test_json_and_memory_utils.py @@ -1,20 +1,14 @@ """JSON 载荷解析、证据格式化、Story 批量规划校验(纯函数)。""" -import pytest - from app.agents.chat.reply_limits import truncate_chat_segments - from app.agents.memoir.classification_agent import _normalize_llm_category -from app.agents.memoir.prompts import format_evidence_chunks_for_prompt -from app.features.memory.evidence_format import ( - format_evidence_chunks_for_prompt as format_evidence_from_memory, -) from app.agents.memoir.story_route_agent import ( StoryBatchPlan, StoryBatchPlanUnit, validate_story_batch_plan, ) from app.core.json_utils import extract_json_payload +from app.features.memory.evidence_format import format_evidence_chunks_for_prompt def test_extract_json_payload_strips_markdown_fence() -> None: @@ -34,29 +28,19 @@ def test_normalize_llm_category_strips_quotes() -> None: assert _normalize_llm_category("`beliefs`") == "beliefs" -def test_format_evidence_chunks_includes_timeline() -> None: +def test_format_evidence_chunks_uses_memory_formatter_without_timeline() -> None: ev = { "relevant_chunks": [{"content": "chunk1"}], "relevant_facts": [ {"subject": "我", "predicate": "生于", "object_json": "1950"} ], - "timeline_hints": [ - { - "id": "1", - "event_year": 1977, - "event_date": None, - "title": "恢复高考", - "description": "参加了考试", - } - ], "relevant_summaries": [], "relevant_stories": [], } out = format_evidence_chunks_for_prompt(ev) assert "chunk1" in out assert "1950" in out or "生于" in out - assert "1977" in out or "恢复高考" in out - assert format_evidence_from_memory(ev) == out + assert "恢复高考" not in out def test_validate_story_batch_plan_ok() -> None: diff --git a/api/tests/test_memoir_pipeline_optimization.py b/api/tests/test_memoir_pipeline_optimization.py index 67c1c80..fa1f0dd 100644 --- a/api/tests/test_memoir_pipeline_optimization.py +++ b/api/tests/test_memoir_pipeline_optimization.py @@ -106,7 +106,10 @@ def 
test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> orch._prepare_batches_via_batch_llm = fail_batch orch.extraction_agent.extract = MagicMock( - return_value=ExtractionResult(detected_stage="childhood", slots={"toy": "ball"}) + return_value=ExtractionResult( + detected_stage="childhood", + slots={"place": "潍坊"}, + ) ) orch.classification_agent.classify = MagicMock( return_value=ChapterClassifyResult(category="childhood", llm_said_none=False) @@ -134,6 +137,52 @@ def test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> assert "s1" in result.segment_chapter_category +def test_orchestrator_sequential_filters_invalid_slots( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Sequential fallback should match batch path slot validation.""" + monkeypatch.setattr( + "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled", + False, + ) + + orch = MemoirOrchestrator() + orch.extraction_agent.extract = MagicMock( + return_value=ExtractionResult( + detected_stage="childhood", + slots={"place": "潍坊", "hallucinated": "bad"}, + ) + ) + orch.classification_agent.classify = MagicMock( + return_value=ChapterClassifyResult(category="childhood", llm_said_none=False) + ) + + st = MemoirStateSchema( + stage_order=["childhood"], + current_stage="childhood", + covered_stages=[], + slots={}, + ) + calls: list[tuple] = [] + + class _Seg: + id = "s1" + user_input_text = "我小时候在潍坊。" + + def update_slot(*args): + calls.append(args) + return st + + orch.prepare_batches( + segments=[_Seg()], + llm=MagicMock(), + get_or_create_state=lambda: st, + update_slot=update_slot, + ) + + assert calls == [("childhood", "place", "潍坊", ["s1"])] + + # --------------------------------------------------------------------------- # Memory enrichment decoupled from ingest # --------------------------------------------------------------------------- @@ -216,6 +265,33 @@ def test_resolve_append_target_forced_new_on_overflow() -> None: assert dsrc == 
"forced_new_due_to_append_limit" +def test_resolve_append_target_does_not_guardrail_route_fallback() -> None: + """No-LLM / parse fallback new_story decisions must not append by recency.""" + from app.features.memoir.story_pipeline_sync import _resolve_append_target + + session = MagicMock() + candidate = MagicMock() + candidate.id = "story-1" + + tid, existing, dsrc = _resolve_append_target( + session, + route_decision="new_story", + route_target_story_id=None, + user_id="u1", + chapter_category="childhood", + oral_norm="short text", + candidate_stories=[candidate], + story_meta={"story-1": {"char_count": 10, "version_count": 1}}, + decision_source="no_llm", + memoir_correlation_id=None, + ) + + assert tid is None + assert existing == "" + assert dsrc == "no_llm" + session.get.assert_not_called() + + # --------------------------------------------------------------------------- # _run_post_pipeline_commit helper # --------------------------------------------------------------------------- diff --git a/api/tests/test_memoir_route_defer.py b/api/tests/test_memoir_route_defer.py new file mode 100644 index 0000000..6fa51d4 --- /dev/null +++ b/api/tests/test_memoir_route_defer.py @@ -0,0 +1,437 @@ +"""Phase2 路由低置信延迟管线:deferred 池 / 唤醒 / 重试上限。""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import DEFAULT, MagicMock, patch + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.orm import sessionmaker + +# 与 alembic/env.py 一致:注册全部 ORM,避免 relationship 解析失败 +from app.agents.memoir.story_route_agent import StoryRouteDecision +from app.agents.state_schema import MemoirStateSchema +from app.core.config import settings +from app.core.db import Base +from app.features.asset import models as _asset_models # noqa: F401 +from app.features.auth import models as _auth_models # noqa: F401 +from app.features.conversation import models as 
_conv_models # noqa: F401 +from app.features.conversation.models import Conversation, Segment +from app.features.memoir import models as _memoir_models # noqa: F401 +from app.features.memoir.story_pipeline_sync import ( + StoryPipelineResult, + run_story_pipeline_for_category_batch, +) +from app.features.memory import models as _memory_models # noqa: F401 +from app.features.payment import models as _payment_models # noqa: F401 +from app.features.story import models as _story_models # noqa: F401 +from app.features.user import models as _user_models # noqa: F401 +from app.features.user.models import User +from app.tasks.memoir_tasks import ( + _persist_phase2_route_defer, + _wake_deferred_segments_for_category, +) + + +@pytest.fixture +def sqlite_session_factory(): + engine = create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all( + engine, + tables=[ + User.__table__, + Conversation.__table__, + Segment.__table__, + ], + ) + yield sessionmaker(bind=engine, expire_on_commit=False, future=True) + engine.dispose() + + +def _seed_user_segment( + db, + *, + user_id: str, + conversation_id: str, + segment_id: str, + text: str = "我童年的事情很短暂", + topic_category: str = "childhood", +) -> Segment: + if not db.get(User, user_id): + db.add( + User( + id=user_id, + phone=f"p-{user_id[:8]}", + password_hash="x", + nickname="t", + ) + ) + if not db.get(Conversation, conversation_id): + db.add(Conversation(id=conversation_id, user_id=user_id)) + seg = Segment( + id=segment_id, + conversation_id=conversation_id, + user_input_text=text, + topic_category=topic_category, + narrated=False, + skip_narrative=False, + narrative_defer_count=0, + ) + db.add(seg) + db.commit() + return seg + + +def _patch_pipeline(plan_return, decide_return): + """统一 mock pipeline 内的 IO 与 LLM 依赖,便于聚焦路由分支。 + + 返回 ``(context_manager, route_agent_mock)``;进入 context 后由 ``patch.multiple`` + 生成的 mock dict 作为 ``mocks`` 提供给测试用例配置返回值与断言。 + """ + route_agent_mock = MagicMock() + 
route_agent_mock.plan_batch.return_value = plan_return + route_agent_mock.decide.return_value = decide_return + + return ( + patch.multiple( + "app.features.memoir.story_pipeline_sync", + list_active_stories_for_user_sync=DEFAULT, + StoryRouteAgent=DEFAULT, + NarrativeAgent=DEFAULT, + normalize_oral_for_memoir=DEFAULT, + ensure_chapter_story_link_sync=DEFAULT, + reorder_chapter_story_links_by_life_order_sync=DEFAULT, + mark_chapter_dirty_sync=DEFAULT, + chapter_needs_cover_enqueue=DEFAULT, + MemoirImageSettings=DEFAULT, + refresh_chapter_evidence_snapshot_with_retry_sync=DEFAULT, + create_story_with_version_sync=DEFAULT, + _ensure_chapter_record=DEFAULT, + ), + route_agent_mock, + ) + + +def _configure_pipeline_mocks(mocks: dict, route_agent_mock: MagicMock) -> None: + mocks["list_active_stories_for_user_sync"].return_value = [] + mocks["StoryRouteAgent"].return_value = route_agent_mock + mocks["normalize_oral_for_memoir"].side_effect = lambda text, **_: text + mocks["chapter_needs_cover_enqueue"].return_value = False + mocks["MemoirImageSettings"].from_env.return_value = MagicMock(enabled=False) + + +def _empty_state() -> MemoirStateSchema: + return MemoirStateSchema( + stage_order=["childhood"], + current_stage="childhood", + covered_stages=[], + slots={}, + ) + + +@pytest.mark.parametrize("reason", ["no_llm", "parse_error", "invalid_target"]) +def test_pipeline_defers_on_fallback_route_reason(reason: str) -> None: + """单段路由 fallback 时不写 chapter/story,返回 deferred 结果。""" + seg = SimpleNamespace(id="seg-defer-1", user_input_text="一句简短的口述") + decide_return = StoryRouteDecision( + decision="new_story", + new_story_title=None, + reason=reason, + ) + cm, route_agent_mock = _patch_pipeline( + plan_return=None, + decide_return=decide_return, + ) + with cm as mocks: + _configure_pipeline_mocks(mocks, route_agent_mock) + session = MagicMock() + exec_result = MagicMock() + exec_result.unique.return_value.scalar_one_or_none.return_value = None + session.execute.return_value 
= exec_result + + result = run_story_pipeline_for_category_batch( + session, + user_id="user-defer", + chapter_category="childhood", + category_segments=[seg], + state=_empty_state(), + user_profile="", + user_birth_year=None, + llm=object(), + memory_evidence={ + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "relevant_stories": [], + }, + ) + + assert isinstance(result, StoryPipelineResult) + assert result.deferred is True + assert result.chapter is None + assert result.dispatch_ids == set() + assert result.defer_reason == reason + assert result.defer_segment_ids == ["seg-defer-1"] + mocks["_ensure_chapter_record"].assert_not_called() + mocks["create_story_with_version_sync"].assert_not_called() + mocks["mark_chapter_dirty_sync"].assert_not_called() + route_agent_mock.decide.assert_called_once() + + +def test_pipeline_does_not_defer_when_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """关闭开关后,旧行为:直接写 new_story(不再延迟)。""" + monkeypatch.setattr(settings, "memoir_route_defer_enabled", False) + + seg = SimpleNamespace(id="seg-no-defer", user_input_text="一句简短的口述") + decide_return = StoryRouteDecision( + decision="new_story", + new_story_title=None, + reason="no_llm", + ) + cm, route_agent_mock = _patch_pipeline( + plan_return=None, + decide_return=decide_return, + ) + with cm as mocks: + _configure_pipeline_mocks(mocks, route_agent_mock) + chapter_stub = SimpleNamespace(id="chapter-1") + mocks["_ensure_chapter_record"].return_value = chapter_stub + story_stub = MagicMock() + story_stub.id = "story-x" + story_stub.current_version_id = None + mocks["create_story_with_version_sync"].return_value = story_stub + + # NarrativeAgent.generate_narrative 必须返回有效 JSON + nac_instance = mocks["NarrativeAgent"].return_value + nac_instance.generate_narrative.return_value = ( + '{"paragraphs": [{"content": "叙事正文段落足够长用于测试合并逻辑避免触发过短回退"}]}' + ) + + session = MagicMock() + exec_result = MagicMock() + 
exec_result.unique.return_value.scalar_one_or_none.return_value = None + session.execute.return_value = exec_result + + result = run_story_pipeline_for_category_batch( + session, + user_id="user-no-defer", + chapter_category="childhood", + category_segments=[seg], + state=_empty_state(), + user_profile="", + user_birth_year=None, + llm=object(), + memory_evidence={ + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "relevant_stories": [], + }, + ) + + assert isinstance(result, StoryPipelineResult) + assert result.deferred is False + assert result.chapter is chapter_stub + mocks["_ensure_chapter_record"].assert_called_once() + + +def test_pipeline_returns_result_object_for_normal_path() -> None: + """决策非 fallback 时,pipeline 仍按原路径执行并返回 StoryPipelineResult。""" + seg = SimpleNamespace(id="seg-ok", user_input_text="一段足够长的童年口述用于测试正常写入路径") + decide_return = StoryRouteDecision( + decision="new_story", + new_story_title="一个童年故事的新标题", + reason="ok", + ) + cm, route_agent_mock = _patch_pipeline( + plan_return=None, + decide_return=decide_return, + ) + with cm as mocks: + _configure_pipeline_mocks(mocks, route_agent_mock) + chapter_stub = SimpleNamespace(id="chapter-ok") + mocks["_ensure_chapter_record"].return_value = chapter_stub + story_stub = MagicMock() + story_stub.id = "story-ok" + story_stub.current_version_id = None + mocks["create_story_with_version_sync"].return_value = story_stub + + nac_instance = mocks["NarrativeAgent"].return_value + nac_instance.generate_narrative.return_value = ( + '{"paragraphs": [{"content": "叙事正文段落足够长用于测试合并逻辑避免触发过短回退"}]}' + ) + + session = MagicMock() + exec_result = MagicMock() + exec_result.unique.return_value.scalar_one_or_none.return_value = None + session.execute.return_value = exec_result + + result = run_story_pipeline_for_category_batch( + session, + user_id="user-ok", + chapter_category="childhood", + category_segments=[seg], + state=_empty_state(), + user_profile="", + user_birth_year=None, + 
llm=object(), + memory_evidence={ + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "relevant_stories": [], + }, + ) + + assert isinstance(result, StoryPipelineResult) + assert result.deferred is False + assert result.chapter is chapter_stub + + +def test_persist_phase2_route_defer_marks_segment_and_schedules_next( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """首次延迟:写入 defer 元数据并安排下一次 timeout(未达上限)。""" + monkeypatch.setattr(settings, "memoir_route_defer_seconds", 30.0) + monkeypatch.setattr(settings, "memoir_route_defer_max_attempts", 3) + + db = sqlite_session_factory() + seg = _seed_user_segment( + db, + user_id="u-defer-1", + conversation_id=str(uuid.uuid4()), + segment_id="seg-defer-x1", + ) + + with patch( + "app.tasks.memoir_tasks._schedule_phase2_timeout", + return_value="task-id-next", + ) as schedule_mock: + out = _persist_phase2_route_defer( + db, + user_id="u-defer-1", + chapter_category="childhood", + task_id="task-id-current", + memoir_correlation_id="cid-1", + defer_segment_ids=[seg.id], + defer_reason="no_llm", + phase2_started=0.0, + pipeline_elapsed=0.0, + lock_elapsed=0.0, + ) + + assert out["status"] == "deferred" + assert out["segments"] == 1 + assert out["saturated_count"] == 0 + schedule_mock.assert_called_once_with("u-defer-1", "childhood", "cid-1") + + refreshed = db.execute(select(Segment).where(Segment.id == seg.id)).scalar_one() + assert refreshed.narrative_defer_count == 1 + assert refreshed.narrative_defer_reason == "no_llm" + assert refreshed.narrative_deferred_until is not None + assert refreshed.narrative_last_attempt_at is not None + assert refreshed.narrated is False + assert refreshed.processed is False + + +def test_persist_phase2_route_defer_stops_scheduling_at_max_attempts( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """达到 max_attempts 后不再继续派发 timeout,segment 仍保留 defer 元数据。""" + monkeypatch.setattr(settings, "memoir_route_defer_seconds", 
30.0) + monkeypatch.setattr(settings, "memoir_route_defer_max_attempts", 2) + + db = sqlite_session_factory() + seg = _seed_user_segment( + db, + user_id="u-defer-max", + conversation_id=str(uuid.uuid4()), + segment_id="seg-defer-max-1", + ) + seg.narrative_defer_count = 1 + db.commit() + + with patch( + "app.tasks.memoir_tasks._schedule_phase2_timeout", + return_value="should-not-be-called", + ) as schedule_mock: + out = _persist_phase2_route_defer( + db, + user_id="u-defer-max", + chapter_category="childhood", + task_id="task-id-current", + memoir_correlation_id="cid-2", + defer_segment_ids=[seg.id], + defer_reason="parse_error", + phase2_started=0.0, + pipeline_elapsed=0.0, + lock_elapsed=0.0, + ) + + assert out["status"] == "deferred" + assert out["saturated_count"] == 1 + schedule_mock.assert_not_called() + + refreshed = db.execute(select(Segment).where(Segment.id == seg.id)).scalar_one() + assert refreshed.narrative_defer_count == 2 + # 达上限后不设 deferred_until,需要等待新素材唤醒;此时 segment 仍可被下次 Phase2 消费 + assert refreshed.narrative_deferred_until is None + assert refreshed.narrative_defer_reason == "parse_error" + + +def test_wake_deferred_segments_clears_defer_metadata( + sqlite_session_factory, +) -> None: + """新素材到达时清空同类目下既有 defer 元数据,并保留另一类目不变。""" + db = sqlite_session_factory() + user_id = "u-wake" + conv_id = str(uuid.uuid4()) + seg_a = _seed_user_segment( + db, + user_id=user_id, + conversation_id=conv_id, + segment_id="seg-wake-1", + topic_category="childhood", + ) + seg_other = _seed_user_segment( + db, + user_id=user_id, + conversation_id=conv_id, + segment_id="seg-other", + topic_category="education", + ) + seg_a.narrative_defer_count = 2 + seg_a.narrative_defer_reason = "parse_error" + seg_a.narrative_deferred_until = datetime.now(timezone.utc) + timedelta(minutes=5) + seg_other.narrative_defer_count = 1 + seg_other.narrative_defer_reason = "no_llm" + seg_other.narrative_deferred_until = datetime.now(timezone.utc) + timedelta( + minutes=5 + ) + db.commit() 
+ + woke = _wake_deferred_segments_for_category(db, user_id, "childhood") + db.commit() + + refreshed_a = db.execute( + select(Segment).where(Segment.id == seg_a.id) + ).scalar_one() + refreshed_other = db.execute( + select(Segment).where(Segment.id == seg_other.id) + ).scalar_one() + + assert woke == 1 + assert refreshed_a.narrative_deferred_until is None + assert refreshed_a.narrative_defer_count == 0 + assert refreshed_a.narrative_defer_reason is None + # 其它类目不应被波及 + assert refreshed_other.narrative_deferred_until is not None + assert refreshed_other.narrative_defer_count == 1 + assert refreshed_other.narrative_defer_reason == "no_llm" diff --git a/api/tests/test_memoir_two_phase.py b/api/tests/test_memoir_two_phase.py index 702c1fe..e1a6ef5 100644 --- a/api/tests/test_memoir_two_phase.py +++ b/api/tests/test_memoir_two_phase.py @@ -16,7 +16,7 @@ def test_segment_chapter_category_populated() -> None: orch = MemoirOrchestrator() orch.extraction_agent.extract = MagicMock( return_value=ExtractionResult( - detected_stage="childhood", slots={"toy": "布娃娃"} + detected_stage="childhood", slots={"daily_life": "玩布娃娃"} ) ) orch.classification_agent.classify = MagicMock( diff --git a/api/tests/test_memory_boundaries.py b/api/tests/test_memory_boundaries.py index 1baf7c9..7f12e07 100644 --- a/api/tests/test_memory_boundaries.py +++ b/api/tests/test_memory_boundaries.py @@ -9,6 +9,17 @@ from app.features.memory.prompt_adapter import MemoryPromptAdapter from app.features.memory.runtime_types import MemoryEvidenceBundle +def test_chunk_transcript_applies_configured_overlap() -> None: + from app.features.memory.chunker import chunk_transcript + + text = "".join(str(i % 10) for i in range(250)) + chunks = chunk_transcript(text, max_chars=100, overlap_chars=20) + + assert len(chunks) >= 3 + assert chunks[0][-20:] == chunks[1][:20] + assert chunks[1][-20:] == chunks[2][:20] + + def test_memory_evidence_bundle_and_prompt_adapter_contract() -> None: evidence = 
MemoryEvidenceBundle.from_mapping( { @@ -17,7 +28,6 @@ def test_memory_evidence_bundle_and_prompt_adapter_contract() -> None: ], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } ) @@ -52,7 +62,6 @@ async def test_memory_retrieval_service_delegates_to_retriever( "relevant_chunks": [{"id": "c1", "content": "chunk"}], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } @@ -117,17 +126,18 @@ async def test_memory_ingest_service_commits_before_enrichment( events.append(("create_chunk", kwargs["chunk_index"], kwargs["content"])) return FakeRow(f"chunk-{kwargs['chunk_index']}") - async def fake_update_chunk_embedding(db, chunk_id, emb): - events.append(("update_embedding", chunk_id, tuple(emb))) + class FakeEmbeddingService: + def __init__(self, db, *, embedding_provider=None) -> None: + events.append(("embedding_service", embedding_provider is not None)) + + async def embed_source(self, user_id: str, source_id: str) -> dict: + events.append(("embed_source", user_id, source_id)) + return {"status": "success", "vectors_written": 2} monkeypatch.setattr(ingest_mod, "chunk_transcript", lambda text: ["a", "b"]) monkeypatch.setattr(ingest_mod, "create_source", fake_create_source) monkeypatch.setattr(ingest_mod, "create_chunk", fake_create_chunk) - monkeypatch.setattr( - ingest_mod, - "update_chunk_embedding", - fake_update_chunk_embedding, - ) + monkeypatch.setattr(ingest_mod, "MemoryEmbeddingService", FakeEmbeddingService) source_id = await MemoryIngestService( FakeDb(), @@ -139,9 +149,150 @@ async def test_memory_ingest_service_commits_before_enrichment( assert events.index(("commit",)) < events.index( ("schedule", "user-1", "source-1") ) - assert ("embed_texts", ("a", "b")) in events - assert ("update_embedding", "chunk-0", (1.0,)) in events - assert ("update_embedding", "chunk-1", (2.0,)) in events + assert ("embed_source", "user-1", "source-1") in events + + +@pytest.mark.asyncio 
+async def test_memory_ingest_succeeds_and_retries_when_embedding_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from app.features.memory import ingest_service as ingest_mod + from app.features.memory.ingest_service import MemoryIngestService + + events: list[tuple] = [] + + @dataclass + class FakeRow: + id: str + + class FakeDb: + async def flush(self) -> None: + events.append(("flush",)) + + async def commit(self) -> None: + events.append(("commit",)) + + class FakeEmbeddingService: + def __init__(self, db, *, embedding_provider=None) -> None: + pass + + async def embed_source(self, user_id: str, source_id: str) -> dict: + events.append(("embed_source", user_id, source_id)) + return {"status": "failed", "error": "upstream_timeout"} + + class FakeEmbeddingScheduler: + def schedule(self, request) -> str: + events.append(("embed_retry", request.user_id, request.source_id)) + return "embed-retry-1" + + class FakeEmbedding: + def is_available(self) -> bool: + return True + + class FakeEnrichmentScheduler: + def schedule(self, request) -> str: + events.append(("enrich", request.user_id, request.source_id)) + return "enrich-1" + + async def fake_create_source(db, **kwargs): + events.append(("create_source", kwargs["user_id"], kwargs["conversation_id"])) + return FakeRow("source-1") + + async def fake_create_chunk(db, **kwargs): + events.append(("create_chunk", kwargs["chunk_index"], kwargs["content"])) + return FakeRow(f"chunk-{kwargs['chunk_index']}") + + monkeypatch.setattr(ingest_mod, "chunk_transcript", lambda text: ["a"]) + monkeypatch.setattr(ingest_mod, "create_source", fake_create_source) + monkeypatch.setattr(ingest_mod, "create_chunk", fake_create_chunk) + monkeypatch.setattr(ingest_mod, "MemoryEmbeddingService", FakeEmbeddingService) + + source_id = await MemoryIngestService( + FakeDb(), + embedding_provider=FakeEmbedding(), + embedding_scheduler=FakeEmbeddingScheduler(), + enrichment_scheduler=FakeEnrichmentScheduler(), + 
).ingest_transcript("user-1", "conv-1", "hello") + + assert source_id == "source-1" + assert ("embed_retry", "user-1", "source-1") in events + assert ("enrich", "user-1", "source-1") in events + assert events.index(("commit",)) < events.index( + ("embed_source", "user-1", "source-1") + ) + + +@pytest.mark.asyncio +async def test_exclude_chunk_stales_derived_facts(monkeypatch: pytest.MonkeyPatch) -> None: + from app.features.memory import service as service_mod + from app.features.memory.service import MemoryService + + events: list[tuple] = [] + + class FakeDb: + async def commit(self) -> None: + events.append(("commit",)) + + async def fake_set_chunk_excluded(db, chunk_id, user_id, excluded): + events.append(("set_excluded", chunk_id, user_id, excluded)) + return True + + async def fake_stale(db, *, user_id, chunk_id): + events.append(("stale_facts", user_id, chunk_id)) + return 2 + + async def fake_curation(db, **kwargs): + events.append(("curation", kwargs)) + + monkeypatch.setattr(service_mod, "set_chunk_excluded", fake_set_chunk_excluded) + monkeypatch.setattr( + service_mod, + "mark_facts_stale_for_excluded_chunk", + fake_stale, + ) + monkeypatch.setattr(service_mod, "create_curation_action", fake_curation) + + ok = await MemoryService(FakeDb()).exclude_chunk( + "user-1", + "chunk-1", + reason="wrong memory", + ) + + assert ok is True + assert ("stale_facts", "user-1", "chunk-1") in events + curation = [ev for ev in events if ev[0] == "curation"][0][1] + assert curation["details"] == { + "reason": "wrong memory", + "staled_fact_count": 2, + } + + +@pytest.mark.asyncio +async def test_restore_chunk_records_reenrichment_policy( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from app.features.memory import service as service_mod + from app.features.memory.service import MemoryService + + captured: list[dict] = [] + + class FakeDb: + async def commit(self) -> None: + pass + + async def fake_set_chunk_excluded(db, chunk_id, user_id, excluded): + return True + + 
async def fake_curation(db, **kwargs): + captured.append(kwargs) + + monkeypatch.setattr(service_mod, "set_chunk_excluded", fake_set_chunk_excluded) + monkeypatch.setattr(service_mod, "create_curation_action", fake_curation) + + ok = await MemoryService(FakeDb()).restore_chunk("user-1", "chunk-1") + + assert ok is True + assert captured[0]["details"] == {"fact_restore_policy": "requires_reenrichment"} def test_memory_single_chain_architecture_guard() -> None: @@ -158,6 +309,10 @@ def test_memory_single_chain_architecture_guard() -> None: "memory_fact_search_use_recent" + "_fallback", "memory_evidence_empty_query_include" + "_rolling", "_interview_meta" + "_store", + "timeline" + "_hints", + "parse_json" + "_payload", + "from app.agents.memoir.prompts import " + "format_evidence_chunks_for_prompt", ] roots = [ repo_root / "api" / "app", diff --git a/api/tests/test_memory_enrichment_baseline.py b/api/tests/test_memory_enrichment_baseline.py index a360dee..76b1460 100644 --- a/api/tests/test_memory_enrichment_baseline.py +++ b/api/tests/test_memory_enrichment_baseline.py @@ -7,7 +7,7 @@ from types import SimpleNamespace import pytest from app.features.memory.enrichment import enrich_memory_after_ingest_async -from app.features.memory.llm_schemas import EnrichmentPayload, parse_json_payload +from app.features.memory.llm_schemas import EnrichmentPayload from app.features.memory.models import MemorySource from app.features.user.models import User @@ -19,8 +19,7 @@ def test_enrichment_payload_roundtrip() -> None: '"object_json":{"value":"北京","approximate_era":"1990年代"},' '"confidence":0.85,"source_chunk_id":"ch-1"}]}' ) - p = parse_json_payload(raw, EnrichmentPayload) - assert p is not None + p = EnrichmentPayload.model_validate_json(raw) assert p.summary == "要点摘要" assert len(p.facts) == 1 assert p.facts[0].subject == "王伟" @@ -36,16 +35,27 @@ async def test_enrich_memory_after_ingest_async_single_llm_call( invoke_count = {"n": 0} - async def fake_invoke(llm, prompt, 
max_tokens, agent): + async def fake_run(llm, numbered, narrator_label): invoke_count["n"] += 1 - assert agent == "memory.enrichment" - return ( - '{"summary":"本轮要点",' - '"facts":[{"fact_type":"event","subject":"王伟","predicate":"住",' - '"object_json":{"value":"上海"},"confidence":0.8,"source_chunk_id":"ch1"}]}' + assert "[chunk_id=ch1]" in numbered + assert narrator_label == "老王" + return EnrichmentPayload.model_validate( + { + "summary": "本轮要点", + "facts": [ + { + "fact_type": "event", + "subject": "王伟", + "predicate": "住", + "object_json": {"value": "上海"}, + "confidence": 0.8, + "source_chunk_id": "ch1", + } + ], + } ) - monkeypatch.setattr(mod, "ainvoke_json_object", fake_invoke) + monkeypatch.setattr(mod, "_run_enrichment_llm_async", fake_run) summaries: list[dict] = [] facts: list[dict] = [] @@ -74,7 +84,7 @@ async def test_enrich_memory_after_ingest_async_single_llm_call( if model is User and key == "u1": return SimpleNamespace(nickname="老王") if model is MemorySource and key == "src-1": - return SimpleNamespace(lineage_json=None) + return SimpleNamespace(user_id="u1", lineage_json=None) return None async def execute(self, _stmt): @@ -105,10 +115,10 @@ async def test_enrich_memory_skips_when_parse_returns_none( monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", True) - async def fake_invoke(*_args, **_kwargs): - return "{not json" + async def fake_run(*_args, **_kwargs): + return None - monkeypatch.setattr(mod, "ainvoke_json_object", fake_invoke) + monkeypatch.setattr(mod, "_run_enrichment_llm_async", fake_run) called = {"summary": False, "fact": False} async def capture_summary(*_args, **_kwargs): @@ -135,7 +145,7 @@ async def test_enrich_memory_skips_when_parse_returns_none( if model is User and key == "u": return None if model is MemorySource and key == "s": - return SimpleNamespace(lineage_json=None) + return SimpleNamespace(user_id="u", lineage_json=None) return None async def execute(self, _stmt): diff --git 
a/api/tests/test_memory_evidence.py b/api/tests/test_memory_evidence.py index 2ad4479..db3d3f7 100644 --- a/api/tests/test_memory_evidence.py +++ b/api/tests/test_memory_evidence.py @@ -7,7 +7,6 @@ from app.features.memory.evidence import ( EMPTY_EVIDENCE_BUNDLE, _facts_to_dicts, _stories_to_dicts, - _timeline_to_dicts, retrieve_evidence_bundle_async, ) from app.features.memory.evidence_format import format_evidence_chunks_for_chat_prompt @@ -19,7 +18,6 @@ def test_empty_evidence_bundle_keys() -> None: "relevant_chunks", "relevant_summaries", "relevant_facts", - "timeline_hints", "relevant_stories", } @@ -31,7 +29,6 @@ def test_evidence_bundle_model_accepts_dict() -> None: def test_format_helpers_empty() -> None: assert _facts_to_dicts([]) == [] - assert _timeline_to_dicts([]) == [] assert _stories_to_dicts([]) == [] @@ -42,7 +39,6 @@ def test_format_evidence_chunks_for_chat_prompt_reframes_and_labels() -> None: ], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } text = format_evidence_chunks_for_chat_prompt(evidence) @@ -73,7 +69,6 @@ def test_slice_interview_memory_retrieval_not_equal_inject_dismissive(): ], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } s = slice_interview_memory(evidence, "哈哈,早就不会了") @@ -92,7 +87,6 @@ def test_slice_interview_memory_minimal_inject_when_aligned(): ], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } s = slice_interview_memory(evidence, "那次排练其实挺紧张的,灯光一打我就忘词。") @@ -111,7 +105,6 @@ def test_slice_interview_memory_keeps_first_person_but_marks_ownership(): ], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], } s = slice_interview_memory(evidence, "那条河一到夏天就特别热闹,我现在都记得。") @@ -129,7 +122,6 @@ def test_slice_interview_memory_suppresses_long_new_topic(): ], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": 
[], } long_msg = "我今天想随便聊聊工作里的事,项目压力很大。" * 6 @@ -153,7 +145,6 @@ async def test_retrieve_evidence_bundle_async_non_empty_merges_precomputed_chunk "object_json": {}, } ], - "timeline_hints": [], "relevant_summaries": [ { "id": "s1", diff --git a/api/tests/test_profile_agent_gateway.py b/api/tests/test_profile_agent_gateway.py new file mode 100644 index 0000000..5e7fd71 --- /dev/null +++ b/api/tests/test_profile_agent_gateway.py @@ -0,0 +1,85 @@ +"""ProfileAgent LLM gateway injection regression tests.""" + +from __future__ import annotations + +import json +from types import SimpleNamespace + +import pytest + +from app.agents.chat.profile_agent import ProfileAgent + + +class _Response: + def __init__(self, content: str) -> None: + self.content = content + + +class _BoundJsonLlm: + async def ainvoke(self, _prompt: str) -> _Response: + return _Response( + json.dumps( + { + "birth_year": 1988, + "birth_place": "杭州", + "grew_up_place": "杭州", + "occupation": "工程师", + } + ) + ) + + +class _JsonLlm: + def bind(self, **_kwargs) -> _BoundJsonLlm: # noqa: ANN003 + return _BoundJsonLlm() + + +class _Provider: + langchain_llm = _JsonLlm() + + def __init__(self) -> None: + self.messages: list[dict] = [] + + async def complete(self, messages: list[dict], **_kwargs) -> str: # noqa: ANN003 + self.messages = messages + return "谢谢分享!还能再说说吗?" 
+ + async def stream(self, *_args, **_kwargs): # noqa: ANN003 + if False: + yield "" + + +@pytest.mark.asyncio +async def test_profile_agent_llm_provider_injection_covers_chat_and_json( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def fake_history(*_args, **_kwargs): + return SimpleNamespace(window=[], turn_total=0) + + monkeypatch.setattr( + "app.agents.chat.profile_agent.get_history_with_window", + fake_history, + ) + provider = _Provider() + agent = ProfileAgent(llm_provider=provider) + + extracted = await agent.extract_profile_from_message( + "我是一名工程师,1988 年出生在杭州。", + ["birth_year", "birth_place", "occupation"], + ) + followup = await agent.generate_profile_followup( + conversation_id="c1", + user_message="我在杭州长大。", + missing_fields=["grew_up_place"], + filled_fields={"birth_year": "1988"}, + ) + + assert extracted == { + "birth_year": 1988, + "birth_place": "杭州", + "grew_up_place": "杭州", + "occupation": "工程师", + } + assert followup + assert provider.messages + assert provider.messages[0]["role"] == "system" diff --git a/api/tests/test_stage_validation.py b/api/tests/test_stage_validation.py index 39c471a..11bd76c 100644 --- a/api/tests/test_stage_validation.py +++ b/api/tests/test_stage_validation.py @@ -8,6 +8,7 @@ from app.agents.memoir.extraction_agent import ExtractionAgent from app.agents.memoir.schemas import StateExtractionOutput from app.agents.stage_constants import ( chat_bucket, + filter_stage_slots, normalize_chapter_category, normalize_chat_stage, ) @@ -41,6 +42,13 @@ def test_chat_bucket() -> None: assert chat_bucket("beliefs") == "belief" +def test_filter_stage_slots_uses_canonical_keys() -> None: + assert filter_stage_slots( + "childhood", + {"place": "潍坊", "toy": "ball"}, + ) == {"place": "潍坊"} + + def test_extraction_agent_normalizes_detected_stage( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/api/tests/test_state_service_batch_stage_policy.py b/api/tests/test_state_service_batch_stage_policy.py index 704c8a8..eb6f4ef 
100644 --- a/api/tests/test_state_service_batch_stage_policy.py +++ b/api/tests/test_state_service_batch_stage_policy.py @@ -155,7 +155,7 @@ def test_update_slot_sync_batch_flag_true_same_bucket_updates_row( update_slot_sync( uid, "career_achievement", - "peak", + "growth", "won prize", ["s2"], db, @@ -165,7 +165,32 @@ def test_update_slot_sync_batch_flag_true_same_bucket_updates_row( select(MemoirStateModel).where(MemoirStateModel.user_id == uid) ).scalar_one() assert st.current_stage == "career" - assert st.slots.get("career", {}).get("peak") is not None + assert st.slots.get("career", {}).get("growth") is not None + + +def test_update_slot_sync_ignores_invalid_slot_name( + sqlite_session_factory, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + uid = "u-invalid-slot" + db = sqlite_session_factory() + _add_user_and_state(db, user_id=uid, current_stage="childhood") + + update_slot_sync( + uid, + "childhood", + "made_up_key", + "bad", + ["s-bad"], + db, + memoir_batch=True, + ) + st = db.execute( + select(MemoirStateModel).where(MemoirStateModel.user_id == uid) + ).scalar_one() + assert "made_up_key" not in (st.slots.get("childhood") or {}) + assert st.current_stage == "childhood" def test_update_slot_sync_batch_flag_true_cross_bucket_unchanged( diff --git a/api/tests/test_story_route_oral_invariant.py b/api/tests/test_story_route_oral_invariant.py index 2642679..9a198bc 100644 --- a/api/tests/test_story_route_oral_invariant.py +++ b/api/tests/test_story_route_oral_invariant.py @@ -4,7 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch -from app.agents.memoir.prompts import format_evidence_chunks_for_prompt from app.agents.memoir.story_route_agent import StoryRouteDecision from app.agents.state_schema import MemoirStateSchema from app.features.asset import models as _asset_models # noqa: F401 @@ -15,6 +14,7 @@ from 
app.features.memoir.story_pipeline_sync import ( run_story_pipeline_for_category_batch, ) from app.features.memory import models as _memory_models # noqa: F401 +from app.features.memory.evidence_format import format_evidence_chunks_for_prompt from app.features.payment import models as _payment_models # noqa: F401 from app.features.story import models as _story_models # noqa: F401 from app.features.user import models as _user_models # noqa: F401 @@ -47,7 +47,6 @@ def test_single_segment_decide_receives_only_combined_text_not_evidence() -> Non } ], "relevant_facts": [{"subject": "X", "predicate": "y", "object_json": {}}], - "timeline_hints": [], "relevant_stories": [], } evidence_formatted = format_evidence_chunks_for_prompt(evidence_payload) @@ -236,7 +235,6 @@ def test_decide_receives_only_same_stage_story_candidates() -> None: "relevant_chunks": [], "relevant_summaries": [], "relevant_facts": [], - "timeline_hints": [], "relevant_stories": [], }, ) diff --git a/api/tests/test_story_route_prompts_and_behavior.py b/api/tests/test_story_route_prompts_and_behavior.py index 8b7616e..e5944d1 100644 --- a/api/tests/test_story_route_prompts_and_behavior.py +++ b/api/tests/test_story_route_prompts_and_behavior.py @@ -134,7 +134,7 @@ def test_decide_career_mock_llm_new_story_and_prompt_episodic(): assert "经历叙事" in captured["prompt"] -def test_decide_invalid_target_falls_back_to_default_append(): +def test_decide_invalid_target_falls_back_to_new_story(): def fake_llm_json(_llm, _prompt: str, _schema: object, **_kwargs): return StoryRouteDecision( decision="append_story", @@ -163,12 +163,12 @@ def test_decide_invalid_target_falls_back_to_default_append(): valid_story_ids={"good"}, story_meta={"good": {"char_count": 2, "version_count": 1}}, ) - assert d.decision == "append_story" - assert d.target_story_id == "good" - assert d.reason == "invalid_target_default_append" + assert d.decision == "new_story" + assert d.target_story_id is None + assert d.reason == "invalid_target" 
-def test_decide_no_llm_defaults_append_when_candidates_exist(): +def test_decide_no_llm_defaults_new_story_when_candidates_exist(): cand = SimpleNamespace( id="s-default", title="求学", @@ -186,8 +186,39 @@ def test_decide_no_llm_defaults_append_when_candidates_exist(): valid_story_ids={"s-default"}, story_meta={"s-default": {"char_count": 4, "version_count": 1}}, ) - assert d.decision == "append_story" - assert d.target_story_id == "s-default" + assert d.decision == "new_story" + assert d.target_story_id is None + assert d.reason == "no_llm" + + +def test_decide_parse_error_fallback_defaults_new_story(): + def fake_llm_json(_llm, _prompt: str, _schema: object, **kwargs): + return kwargs["fallback_factory"]() + + cand = SimpleNamespace( + id="s-default", + title="求学", + summary="y" * 40, + canonical_markdown="本科经历", + updated_at=datetime(2025, 2, 1, tzinfo=timezone.utc), + chapter_links=[], + ) + with patch( + "app.agents.memoir.story_route_agent.llm_json_call", + side_effect=fake_llm_json, + ): + d = StoryRouteAgent().decide( + chapter_category="education", + chapter_title="教育", + batch_transcript="后来又考研。", + candidate_stories=[cand], + llm=MagicMock(), + valid_story_ids={"s-default"}, + story_meta={"s-default": {"char_count": 4, "version_count": 1}}, + ) + assert d.decision == "new_story" + assert d.target_story_id is None + assert d.reason == "parse_error" def test_plan_batch_merges_consecutive_new_story_units(): diff --git a/app-expo/.env.example b/app-expo/.env.example index b1fa4b8..8774536 100644 --- a/app-expo/.env.example +++ b/app-expo/.env.example @@ -4,6 +4,9 @@ # CI:GitHub Actions 在构建 APK 前会按分支调用 use-env(main → staging,tag → production)。 # # 变量在构建时注入;修改后需重新 prebuild/打包客户端。 +# +# 助手朗读:无独立 EXPO_PUBLIC_* TTS 开关。会话页顶栏在每轮 WebSocket 中带 `tts_this_turn`; +# 服务端是否具备合成能力见 api/.env 中 ENABLE_TTS 等(模板见 api/.env.example)。 EXPO_PUBLIC_API_URL=https://your-api.example.com EXPO_PUBLIC_WS_URL=wss://your-api.example.com diff --git a/app-expo/.env.production 
b/app-expo/.env.production index ca37760..d8b30e0 100644 --- a/app-expo/.env.production +++ b/app-expo/.env.production @@ -1,2 +1,3 @@ +# 仅 API/WS 基址;TTS 每轮开关由运行时 WS payload 与服务端 ENABLE_TTS 控制(见 api/.env.example)。 EXPO_PUBLIC_API_URL=https://lifecho.worldsplats.com EXPO_PUBLIC_WS_URL=wss://lifecho.worldsplats.com diff --git a/app-expo/app.config.ts b/app-expo/app.config.ts index 1310211..2755118 100644 --- a/app-expo/app.config.ts +++ b/app-expo/app.config.ts @@ -28,6 +28,8 @@ const LOCALES: Record = { const SUPPORTED_LOCALES = ['zh', 'en'] as const; const PRIMARY_LOCALE = process.env.EXPO_PUBLIC_PRIMARY_LOCALE ?? 'zh'; +const API_BASE_URL = process.env.EXPO_PUBLIC_API_URL ?? ''; +const ALLOW_ANDROID_CLEARTEXT_TRAFFIC = API_BASE_URL.startsWith('http://'); const PERMISSION_FALLBACKS: Record = { microphone: 'Allow $(PRODUCT_NAME) to access your microphone.', @@ -149,11 +151,15 @@ export default ({ config }: ConfigContext): ExpoConfig => { plugins: [ // CI/local release: android/app/keystore.properties + store file → release signing; -PversionName/-PversionCode './plugins/withAndroidReleaseSigning', + [ + './plugins/withAndroidCleartextTraffic', + { enabled: ALLOW_ANDROID_CLEARTEXT_TRAFFIC }, + ], 'expo-router', [ 'expo-splash-screen', { - // 与 android.adaptiveIcon.backgroundColor、品牌浅紫一致(见 scripts/generate-app-icon.sh) + // 与 android.adaptiveIcon.backgroundColor、品牌浅紫一致(见 scripts/generate-app-icon.sh,源图为 assets/logo.png) backgroundColor: '#E6F4FE', image: './assets/images/splash-icon.png', resizeMode: 'contain', diff --git a/app-expo/assets/images/android-icon-foreground.png b/app-expo/assets/images/android-icon-foreground.png index 6f1ca32..4a1b68e 100644 Binary files a/app-expo/assets/images/android-icon-foreground.png and b/app-expo/assets/images/android-icon-foreground.png differ diff --git a/app-expo/assets/images/android-icon-monochrome.png b/app-expo/assets/images/android-icon-monochrome.png index a1dd718..1f69bf2 100644 Binary files 
a/app-expo/assets/images/android-icon-monochrome.png and b/app-expo/assets/images/android-icon-monochrome.png differ diff --git a/app-expo/assets/images/favicon.png b/app-expo/assets/images/favicon.png index 6d2cab3..61e3d5f 100644 Binary files a/app-expo/assets/images/favicon.png and b/app-expo/assets/images/favicon.png differ diff --git a/app-expo/assets/images/icon-alpha.png b/app-expo/assets/images/icon-alpha.png new file mode 100644 index 0000000..f858a2e Binary files /dev/null and b/app-expo/assets/images/icon-alpha.png differ diff --git a/app-expo/assets/images/icon.png b/app-expo/assets/images/icon.png index 5da1cfa..4a1b68e 100644 Binary files a/app-expo/assets/images/icon.png and b/app-expo/assets/images/icon.png differ diff --git a/app-expo/assets/images/splash-icon.png b/app-expo/assets/images/splash-icon.png index 44a8862..2b6210a 100644 Binary files a/app-expo/assets/images/splash-icon.png and b/app-expo/assets/images/splash-icon.png differ diff --git a/app-expo/assets/logo.png b/app-expo/assets/logo.png new file mode 100644 index 0000000..4a1b68e Binary files /dev/null and b/app-expo/assets/logo.png differ diff --git a/app-expo/plugins/withAndroidCleartextTraffic.js b/app-expo/plugins/withAndroidCleartextTraffic.js new file mode 100644 index 0000000..b36af79 --- /dev/null +++ b/app-expo/plugins/withAndroidCleartextTraffic.js @@ -0,0 +1,32 @@ +// @ts-check +/** + * Toggle Android cleartext HTTP traffic from Expo env. + * + * Staging may use an IP:port HTTP endpoint while production remains HTTPS. 
+ */ +const { withAndroidManifest } = require('@expo/config-plugins'); + +/** + * @param {import('expo/config').ExpoConfig} config + * @param {{ enabled?: boolean }} props + */ +function withAndroidCleartextTraffic(config, props = {}) { + return withAndroidManifest(config, (mod) => { + const mainApplication = mod.modResults.manifest.application?.[0]; + + if (!mainApplication) { + throw new Error( + '[withAndroidCleartextTraffic] Main application not found in AndroidManifest.xml.', + ); + } + + mainApplication.$ = mainApplication.$ ?? {}; + mainApplication.$['android:usesCleartextTraffic'] = props.enabled + ? 'true' + : 'false'; + + return mod; + }); +} + +module.exports = withAndroidCleartextTraffic; diff --git a/app-expo/scripts/generate-app-icon.sh b/app-expo/scripts/generate-app-icon.sh index 5bedbc4..fe2650c 100755 --- a/app-expo/scripts/generate-app-icon.sh +++ b/app-expo/scripts/generate-app-icon.sh @@ -1,23 +1,27 @@ #!/usr/bin/env bash -# 一次性:用 ImageMagick 7+(magick)从 assets/life-echo-logo.jpg 生成各平台 PNG。 +# 用 ImageMagick 7+(magick)从 assets/logo.png(或旧的 life-echo-logo.jpg)生成各平台 PNG。 # 依赖:brew install imagemagick # 用法:在 app-expo 目录执行 ./scripts/generate-app-icon.sh set -euo pipefail ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -SRC="$ROOT/assets/life-echo-logo.jpg" OUT="$ROOT/assets/images" # 与 Android adaptiveIcon.backgroundColor、开屏底色一致(新 logo 浅紫系) BRAND_BG="#E6F4FE" +if [[ -f "$ROOT/assets/logo.png" ]]; then + SRC="$ROOT/assets/logo.png" +elif [[ -f "$ROOT/assets/life-echo-logo.jpg" ]]; then + SRC="$ROOT/assets/life-echo-logo.jpg" +else + echo "missing: $ROOT/assets/logo.png or $ROOT/assets/life-echo-logo.jpg" >&2 + exit 1 +fi + TMP="$(mktemp -d)" trap 'rm -rf "$TMP"' EXIT -if [[ ! -f "$SRC" ]]; then - echo "missing: $SRC" >&2 - exit 1 -fi if ! 
command -v magick >/dev/null; then echo "need: magick (ImageMagick 7)" >&2 exit 1 @@ -25,8 +29,12 @@ fi mkdir -p "$OUT" -# 近白背景 → 透明(JPG 白底) -magick "$SRC" -fuzz 12% -transparent white PNG32:"$TMP/fg.png" +# 底图抠透明:JPG 近白底;PNG 主站 logo 为品牌浅底 +if [[ "$SRC" == *.jpg ]] || [[ "$SRC" == *.jpeg ]]; then + magick "$SRC" -fuzz 12% -transparent white PNG32:"$TMP/fg.png" +else + magick "$SRC" -fuzz 12% -transparent "${BRAND_BG}" PNG32:"$TMP/fg.png" +fi # Android 自适应前景(透明底,1024,内容约 78% 安全区) magick "$TMP/fg.png" -resize '800x800>' -background none -gravity center -extent 1024x1024 \ @@ -41,8 +49,12 @@ magick "$OUT/android-icon-foreground.png" -alpha extract -fill white -colorize 1 PNG32:"$OUT/android-icon-monochrome.png" # 开屏图:仅透明 logo(底色由 expo-splash-screen backgroundColor 填充) -magick "$SRC" -fuzz 12% -transparent white -resize '400x400>' PNG32:"$OUT/splash-icon.png" +magick "$TMP/fg.png" -resize '400x400>' PNG32:"$OUT/splash-icon.png" magick "$OUT/icon.png" -resize 48x48 PNG32:"$OUT/favicon.png" -echo "OK → $OUT/icon.png, android-icon-foreground.png, android-icon-monochrome.png, splash-icon.png, favicon.png" +# 带 Alpha 的 logo(与 app 内资源一致时可自行纳入版本库) +magick "$TMP/fg.png" -resize '1024x1024>' -background none -gravity center -extent 1024x1024 \ + PNG32:"$OUT/icon-alpha.png" + +echo "OK → $OUT/icon.png, icon-alpha.png, android-icon-foreground.png, android-icon-monochrome.png, splash-icon.png, favicon.png" diff --git a/app-expo/src/app/(main)/conversation/[id].tsx b/app-expo/src/app/(main)/conversation/[id].tsx index 66ae520..2cf7bed 100644 --- a/app-expo/src/app/(main)/conversation/[id].tsx +++ b/app-expo/src/app/(main)/conversation/[id].tsx @@ -5,7 +5,6 @@ import { Pause, Play, PlusCircle, - Square, Type, Volume2, X, @@ -29,6 +28,7 @@ import { Pressable, ScrollView, StyleSheet, + Switch, Text as RNText, TextInput, type TextStyle, @@ -46,6 +46,7 @@ import { useQueryClient } from '@tanstack/react-query'; import { Icon } from '@/components/ui/icon'; import { Text } from 
'@/components/ui/text'; import { ScreenHeader } from '@/components/screen-header'; +import { resolveApiMediaUrl } from '@/core/api/media-url'; import { useAppSettings } from '@/hooks/use-app-settings'; import { useThemeColors } from '@/hooks/use-theme-colors'; import { useTypography } from '@/core/typography-context'; @@ -53,6 +54,8 @@ import { useMessages, useRealtimeSession } from '@/features/conversation/hooks'; import type { TtsSegmentPayload } from '@/features/conversation/realtime-session'; import type { TopicSuggestion } from '@/core/ws/types'; import { conversationKeys } from '@/features/conversation/query-keys'; +import { useSession } from '@/features/auth/hooks'; +import { useProfile } from '@/features/profile/hooks'; import { assistantSegmentMessageId, splitMessageParts, @@ -84,17 +87,94 @@ const CHAT_COLORS = { /** 与 archive/app-ios-react-app 中 `app-android/.../drawable/avatar_assistant.png` 同源(岁月知己) */ const AGENT_AVATAR = require('@/assets/images/avatar-assistant.png'); -const USER_AVATAR = - 'https://lh3.googleusercontent.com/aida-public/AB6AXuAMCjDBVhsUUXRAz9AGYejbTGoEYhzyiggYt_QIFqHCc3odRcBPNRhsE2Klg7gOeOV9V_qOy5qPqjU0GmpfgjGAWKGXZCizwRVz96N0n1IFMx4JH7QwV81zQsaVvCdJct_uABUBEawhncvQcbl0jUt_EUlNgzB-gIgUS_oLlT1TtRb8S5s7sAqwLRdGBa61yxL1X1iSWSFIn5N-WPIDs_vpCgS47q9SQjkT1q7VKvPzHzTiGF1bwVvjB7Bl2JgtaIUj6rkwlLbPG6xb'; + +function UserChatAvatar({ + uri, + letter, + meLabel, +}: { + uri: string | null; + letter: string; + meLabel: string; +}) { + if (uri) { + return ( + {meLabel} + ); + } + return ( + + + {letter} + + + ); +} type InputMode = 'text' | 'voice'; +/** 多段拆条后与后端 `ttsAudioUrls` 下标对齐 */ +function assistantBubbleSegmentIndex( + item: MessageItem, + listKey: string, +): number { + const part = /_part_(\d+)$/.exec(listKey); + if (part) return Number(part[1]); + const seg = /_seg_(\d+)$/.exec(item.id); + if (seg) return Number(seg[1]); + return 0; +} + +/** 按需 TTS 需要的落库助手消息 id */ +function durableAssistantIdForBubble( + item: MessageItem, + 
conversationId: string, +): string | null { + if (item.durableMessageId) return item.durableMessageId; + const m = /^(.+)_seg_\d+$/.exec(item.id); + if (m) return m[1]!; + const sid = item.id; + if (sid.startsWith(`${conversationId}_msg_`) || sid.startsWith('pending')) { + return null; + } + if ( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test(sid) + ) { + return sid; + } + return null; +} + +function segmentTtsUrlAt( + ttsAudioUrls: string[] | undefined, + segmentIndex: number, +): string | null { + if (!ttsAudioUrls?.length) return null; + const u = ttsAudioUrls[segmentIndex]; + return u && u.trim() ? u.trim() : null; +} + /** 流式助手区与自动 TTS 的 `PlaybackItem.messageRef.listKey` 对齐,用于点区域停止朗读 */ const TTS_STREAMING_LIST_KEY = '__tts_streaming__'; -/** 多段拆分后仅首段显示「朗读」控件(整段消息共用 `ttsAudioUrls`) */ -function isFirstAssistantTextPart(listKey: string, messageId: string): boolean { - return listKey === messageId || listKey === `${messageId}_part_0`; +/** PlaybackItem.messageRef.listKey 可与 `item.id` 或 `${id}_seg_/part_` 后缀对齐 */ +function playbackMessageRefMatchesMessage( + playbackListKey: string | undefined, + messageItemId: string, +): boolean { + if (!playbackListKey?.length) return false; + if (playbackListKey === messageItemId) return true; + return playbackListKey.startsWith(`${messageItemId}_`); } /** 展平消息列表:assistant 消息按 [SPLIT] 边界拆成多条,每条一个 listKey */ @@ -141,14 +221,22 @@ function MessageBubble({ currentPlaybackUri, currentPlaybackItem, playbackIsPlaying, + playbackIsPaused, onPlayVoiceExclusive, - onPausePlayback, + onPauseAssistantTts, + onResumeAssistantTts, onInterruptAssistantTts, onReplayAssistantTts, bubbleTextStyle, voiceDurationTextStyle, readAloudIconSize, readAloudButtonSize, + userAvatarUri, + userAvatarLetter, + conversationId, + assistantSegmentIndex, + durableAssistantId, + requestAssistantSegmentTts, }: { item: MessageItem; listKey: string; @@ -157,36 +245,70 @@ function MessageBubble({ currentPlaybackUri: string | null; 
currentPlaybackItem: PlaybackItem | null; playbackIsPlaying: boolean; + playbackIsPaused: boolean; onPlayVoiceExclusive: (uri: string) => void; - onPausePlayback: () => void; + onPauseAssistantTts: () => void; + onResumeAssistantTts: () => void; onInterruptAssistantTts: () => void; onReplayAssistantTts: (messageId: string, urls: string[]) => void; bubbleTextStyle?: TextStyle; voiceDurationTextStyle?: TextStyle; readAloudIconSize: number; readAloudButtonSize: number; + userAvatarUri: string | null; + userAvatarLetter: string; + conversationId: string; + assistantSegmentIndex: number; + durableAssistantId: string | null; + requestAssistantSegmentTts: (body: { + assistantMessageId: string; + segmentIndex: number; + segmentText?: string; + }) => boolean; }) { const { t } = useTranslation('conversation'); const isUser = item.senderType === 'user'; const isVoice = isVoiceMessage(item); - const ttsUrls = - Array.isArray(item.ttsAudioUrls) && item.ttsAudioUrls.length > 0 - ? item.ttsAudioUrls.filter( - (u): u is string => typeof u === 'string' && u.trim().length > 0, - ) - : []; + const ttsUrlThisPart = segmentTtsUrlAt( + item.ttsAudioUrls, + assistantSegmentIndex, + ); - const isAssistantTextFirstPart = - !isUser && !isVoice && isFirstAssistantTextPart(listKey, item.id); - - const isThisBubbleTtsTarget = + const playbackKind = currentPlaybackItem?.kind; + const playbackRefListKey = currentPlaybackItem?.messageRef?.listKey; + const matchesThisMessageForTts = !isUser && !isVoice && - playbackIsPlaying && - currentPlaybackItem?.kind !== 'voice' && - currentPlaybackItem?.messageRef?.listKey === item.id; + playbackKind !== 'voice' && + (playbackRefListKey === listKey || + playbackMessageRefMatchesMessage(playbackRefListKey, item.id)); - const isAssistantTtsHighlight = isThisBubbleTtsTarget; + const playbackEngaged = playbackIsPlaying || playbackIsPaused; + const isThisBubbleActiveTts = matchesThisMessageForTts && playbackEngaged; + + const isThisBubbleTtsPlaying = 
isThisBubbleActiveTts && playbackIsPlaying; + const isThisBubbleTtsPaused = isThisBubbleActiveTts && playbackIsPaused; + + const isAssistantTtsHighlight = isThisBubbleActiveTts; + + const isThisVoiceTrack = + !!item.audioUri && + currentPlaybackUri === item.audioUri && + currentPlaybackItem?.kind === 'voice'; + + const readAloudAccessibilityLabel = isThisBubbleTtsPlaying + ? t('readAloudPause') + : isThisBubbleTtsPaused + ? t('readAloudResume') + : ttsUrlThisPart + ? t('readAloudAgain') + : t('readAloudRequest'); + + const ReadAloudIconComponent = isThisBubbleTtsPlaying + ? Pause + : isThisBubbleTtsPaused + ? Play + : Volume2; const assistantTextBubbleBody = ( - {isAssistantTextFirstPart && - (ttsUrls.length > 0 || isThisBubbleTtsTarget) ? ( + {!isUser && !isVoice ? ( { - onReplayAssistantTts(item.id, ttsUrls); + if (isThisBubbleTtsPlaying) { + onPauseAssistantTts(); + } else if (isThisBubbleTtsPaused) { + onResumeAssistantTts(); + } else if (ttsUrlThisPart) { + onReplayAssistantTts(listKey, [ttsUrlThisPart]); + } else if (durableAssistantId) { + const ok = requestAssistantSegmentTts({ + assistantMessageId: durableAssistantId, + segmentIndex: assistantSegmentIndex, + segmentText: item.content, + }); + if (!ok) { + Alert.alert('', t('readAloudRequestFailed')); + } + } else { + Alert.alert('', t('readAloudNoMessageId')); + } }} style={({ pressed }) => [ styles.readAloudButton, { width: readAloudButtonSize, height: readAloudButtonSize }, - !isThisBubbleTtsTarget && pressed ? { opacity: 0.85 } : null, + pressed ? { opacity: 0.85 } : null, ]} - accessibilityElementsHidden={isThisBubbleTtsTarget} - importantForAccessibility={ - isThisBubbleTtsTarget ? 'no-hide-descendants' : 'auto' - } accessibilityRole="button" - accessibilityLabel={t('readAloudAgain')} - accessibilityState={{ disabled: isThisBubbleTtsTarget }} + accessibilityLabel={readAloudAccessibilityLabel} > @@ -241,13 +373,21 @@ function MessageBubble({ isUser ? 
styles.avatarWrapperUser : styles.avatarWrapperAgent, ]} > - {isUser + {isUser ? ( + + ) : ( + {agentName} + )} {isVoice ? ( @@ -261,19 +401,19 @@ function MessageBubble({ durationSeconds={item.durationSeconds ?? 0} audioUri={item.audioUri} isUser={isUser} - isPlaying={ - !!item.audioUri && - playbackIsPlaying && - currentPlaybackUri === item.audioUri - } + isPlaying={playbackIsPlaying && isThisVoiceTrack} durationTextStyle={voiceDurationTextStyle} onPlayPress={() => { if (!item.audioUri) return; - if (playbackIsPlaying && currentPlaybackUri === item.audioUri) { - onPausePlayback(); - } else { - onPlayVoiceExclusive(item.audioUri); + if (playbackIsPlaying && isThisVoiceTrack) { + onPauseAssistantTts(); + return; } + if (playbackIsPaused && isThisVoiceTrack) { + onResumeAssistantTts(); + return; + } + onPlayVoiceExclusive(item.audioUri); }} /> @@ -288,7 +428,7 @@ function MessageBubble({ ) : ( {assistantTextBubbleBody} - {isThisBubbleTtsTarget ? ( + {isThisBubbleActiveTts ? ( [ @@ -999,6 +1139,18 @@ export default function ConversationScreen() { const { t: tApp } = useTranslation('app'); const typography = useTypography(); const { largeText } = useAppSettings(); + const { user } = useSession(); + const { data: profile } = useProfile(); + + const userAvatarUri = useMemo( + () => resolveApiMediaUrl(user?.avatar_url ?? profile?.avatar_url ?? null), + [user?.avatar_url, profile?.avatar_url], + ); + const userAvatarLetter = useMemo(() => { + const nick = user?.nickname ?? profile?.nickname; + const c = nick?.trim().charAt(0); + return c ? c.toUpperCase() : '?'; + }, [user?.nickname, profile?.nickname]); /** 大字模式:对话气泡与输入使用更大一档,与设置中的「大字」一致 */ const chatBubbleTextStyle = useMemo( @@ -1038,6 +1190,14 @@ export default function ConversationScreen() { }), [typography, largeText], ); + const headerTtsSwitchLabelStyle = useMemo( + () => ({ + fontSize: largeText ? 
typography.bodySmall : typography.captionLarge, + fontWeight: '600' as const, + color: CHAT_COLORS.primary, + }), + [typography, largeText], + ); const inputLineHeight = largeText ? typography.lineHeightLoose : typography.lineHeightNormal; @@ -1099,7 +1259,13 @@ export default function ConversationScreen() { ); const { data: messages } = useMessages(id); + const hasAssistantInHistory = useMemo( + () => (messages ?? []).some((m) => m.senderType === 'assistant'), + [messages], + ); + const ttsGate = useRef(createTtsPlaybackGate()); + const lastUserMessageRequestedTtsRef = useRef(false); const { enqueue, enqueueExclusive, @@ -1107,6 +1273,8 @@ export default function ConversationScreen() { status: playerStatus, currentSource, currentPlaybackItem, + pausePlayback, + resumePlayback, } = usePlayer(); const handleTtsPlaybackResume = useCallback(() => { @@ -1152,7 +1320,16 @@ export default function ConversationScreen() { const prevUrls = target.ttsAudioUrls ?? []; if (prevUrls.includes(cosUrl)) return old; const segmentBind = p.assistantMessageId != null && p.index != null; - const nextUrls = segmentBind ? [cosUrl] : [...prevUrls, cosUrl]; + let nextUrls: string[]; + if (segmentBind) { + const slot = p.index!; + nextUrls = [...prevUrls]; + while (nextUrls.length <= slot) nextUrls.push(''); + if (nextUrls[slot] === cosUrl) return old; + nextUrls[slot] = cosUrl; + } else { + nextUrls = [...prevUrls, cosUrl]; + } const nextId = p.assistantMessageId && p.index == null && @@ -1164,6 +1341,7 @@ export default function ConversationScreen() { next[idx] = { ...target, id: nextId, + durableMessageId: p.assistantMessageId ?? target.durableMessageId, ttsAudioUrls: nextUrls, }; return next; @@ -1171,6 +1349,10 @@ export default function ConversationScreen() { ); } + const shouldEnqueue = + p.manual === true || lastUserMessageRequestedTtsRef.current; + if (!shouldEnqueue) return; + const listKey = p.assistantMessageId != null && p.index != null ? 
assistantSegmentMessageId(p.assistantMessageId, p.index) @@ -1199,9 +1381,13 @@ export default function ConversationScreen() { [enqueueExclusive], ); - const handlePausePlayback = useCallback(() => { - void stop(); - }, [stop]); + const handlePauseAssistantPlayback = useCallback(() => { + pausePlayback(); + }, [pausePlayback]); + + const handleResumeAssistantPlayback = useCallback(() => { + void resumePlayback(); + }, [resumePlayback]); const handleReplayAssistantTts = useCallback( (messageId: string, urls: string[]) => { @@ -1230,11 +1416,15 @@ export default function ConversationScreen() { sendText, sendVoiceMessage, sendTtsCancel, + requestAssistantSegmentTts, } = useRealtimeSession({ conversationId: id ?? '', enabled: !!id, onTtsSegment: handleTtsSegment, onTtsPlaybackResume: handleTtsPlaybackResume, + onUserSendTtsPreference: (r) => { + lastUserMessageRequestedTtsRef.current = r; + }, }); const handleInterruptAssistantTts = useCallback(() => { @@ -1253,14 +1443,20 @@ export default function ConversationScreen() { const [input, setInput] = useState(''); const [inputResetKey, setInputResetKey] = useState(0); + /** 本条发出的用户消息是否请求助手朗读(先 TTS 再出字) */ + const [ttsThisTurn, setTtsThisTurn] = useState(false); const [inputMode, setInputMode] = useState('text'); const [isKeyboardVisible, setIsKeyboardVisible] = useState(false); + const inputModeRef = useRef('text'); const listRef = useRef(null); const textInputRef = useRef(null); /** 底部输入区(含连接提示 + 输入条)高度,用于多行输入增高时把列表滚到底,避免挡住最新消息 */ const composerBlockHeightRef = useRef(null); /** 连接中(connecting)时点发送:排队,连上后自动发出 */ - const pendingTextSendRef = useRef(null); + const pendingTextSendRef = useRef<{ + text: string; + ttsThisTurn: boolean; + } | null>(null); const connectingSendTimeoutRef = useRef | null>( null, ); @@ -1318,8 +1514,8 @@ export default function ConversationScreen() { const handleStopRecording = useCallback(async () => { const result = await stopRecording(); if (!result) return; - void sendVoiceMessage(result.uri, 
result.durationMs); - }, [stopRecording, sendVoiceMessage]); + void sendVoiceMessage(result.uri, result.durationMs, { ttsThisTurn }); + }, [stopRecording, sendVoiceMessage, ttsThisTurn]); const scrollListToEndAfterComposerLayout = useCallback(() => { InteractionManager.runAfterInteractions(() => { @@ -1334,7 +1530,16 @@ export default function ConversationScreen() { * iOS:WillShow 提前标记键盘区,便于底部 inset 与动画同步。 */ useEffect(() => { + inputModeRef.current = inputMode; + }, [inputMode]); + + useEffect(() => { + const onKeyboardWillShow = () => { + if (inputModeRef.current !== 'text') return; + setIsKeyboardVisible(true); + }; const onKeyboardShown = () => { + if (inputModeRef.current !== 'text') return; setIsKeyboardVisible(true); scrollListToEndAfterComposerLayout(); }; @@ -1343,11 +1548,7 @@ export default function ConversationScreen() { }; const subs: ReturnType[] = []; if (Platform.OS === 'ios') { - subs.push( - Keyboard.addListener('keyboardWillShow', () => - setIsKeyboardVisible(true), - ), - ); + subs.push(Keyboard.addListener('keyboardWillShow', onKeyboardWillShow)); } subs.push(Keyboard.addListener('keyboardDidShow', onKeyboardShown)); subs.push( @@ -1361,6 +1562,7 @@ export default function ConversationScreen() { const handleInputModeToggle = useCallback(() => { if (inputMode === 'voice') { + inputModeRef.current = 'text'; setInputMode('text'); return; } @@ -1369,10 +1571,13 @@ export default function ConversationScreen() { /** * `dismiss()` 在当前库版本里会等 `keyboardDidHide` 后才 resolve。 * 这里不能 `await`,否则事件丢失/延迟时会把模式切换卡住。 - * 先切 UI,再让键盘异步收起;外层 AvoidingView 会在键盘真正隐藏前保持启用。 + * 切语音时输入区必须立刻退出键盘避让,否则隐藏事件延迟/丢失会把底部栏卡在键盘高度。 */ + inputModeRef.current = 'voice'; textInputRef.current?.blur(); + setIsKeyboardVisible(false); setInputMode('voice'); + Keyboard.dismiss(); void KeyboardController.dismiss({ animated: false }); }, []); @@ -1401,13 +1606,18 @@ export default function ConversationScreen() { /** 连接中排队:一旦变为 connected,短延迟后发送,给 WebSocket onopen 一点时间 */ useEffect(() => { if 
(connectionState !== 'connected') return; - const text = pendingTextSendRef.current; - if (!text) return; + const pending = pendingTextSendRef.current; + if (!pending) return; const t = setTimeout(() => { - if (pendingTextSendRef.current !== text) return; + if ( + pendingTextSendRef.current?.text !== pending.text || + pendingTextSendRef.current?.ttsThisTurn !== pending.ttsThisTurn + ) { + return; + } pendingTextSendRef.current = null; clearConnectingSendTimeout(); - sendText(text); + sendText(pending.text, { ttsThisTurn: pending.ttsThisTurn }); }, PENDING_SEND_FLUSH_MS); return () => clearTimeout(t); }, [connectionState, sendText, clearConnectingSendTimeout]); @@ -1416,10 +1626,10 @@ export default function ConversationScreen() { useEffect(() => { if (connectionState !== 'disconnected') return; clearConnectingSendTimeout(); - const text = pendingTextSendRef.current; - if (!text) return; + const pending = pendingTextSendRef.current; + if (!pending) return; pendingTextSendRef.current = null; - setInput(text); + setInput(pending.text); }, [connectionState, clearConnectingSendTimeout]); const handleSend = () => { @@ -1430,21 +1640,26 @@ export default function ConversationScreen() { return; } if (connectionState === 'connecting') { - pendingTextSendRef.current = text; + pendingTextSendRef.current = { text, ttsThisTurn }; setInput(''); setInputResetKey((k) => k + 1); scheduleRefocusComposer(); clearConnectingSendTimeout(); connectingSendTimeoutRef.current = setTimeout(() => { connectingSendTimeoutRef.current = null; - if (pendingTextSendRef.current !== text) return; + if ( + pendingTextSendRef.current?.text !== text || + pendingTextSendRef.current?.ttsThisTurn !== ttsThisTurn + ) { + return; + } pendingTextSendRef.current = null; setInput(text); Alert.alert(t('chatUnavailableTitle'), t('chatQueueSendTimeout')); }, CONNECTING_SEND_TIMEOUT_MS); return; } - sendText(text); + sendText(text, { ttsThisTurn }); setInput(''); setInputResetKey((k) => k + 1); 
scheduleRefocusComposer(); @@ -1492,7 +1707,9 @@ export default function ConversationScreen() { ? t('connectionConnecting') : t('connectionDisconnected'); const showConnectionBadge = __DEV__; - const showConnectionNotice = connectionState !== 'connected'; + const showConnectionNotice = + connectionState !== 'connected' && + !(connectionState === 'connecting' && hasAssistantInHistory); const connectionNoticeText = connectionState === 'connecting' ? t('chatUnavailableConnecting') @@ -1507,12 +1724,14 @@ export default function ConversationScreen() { variant="chat" title={ - {tApp('name')} - + {showConnectionBadge ? ( } backAccessibilityLabel={t('chatTitle')} + right={ + + + {t('ttsThisTurn')} + + + + } /> {/* Message list - flex 1, takes remaining space */} @@ -1556,14 +1793,25 @@ export default function ConversationScreen() { currentPlaybackUri={currentSource} currentPlaybackItem={currentPlaybackItem} playbackIsPlaying={playerStatus === 'playing'} + playbackIsPaused={playerStatus === 'paused'} + onPauseAssistantTts={handlePauseAssistantPlayback} + onResumeAssistantTts={handleResumeAssistantPlayback} onPlayVoiceExclusive={handlePlayVoiceExclusive} - onPausePlayback={handlePausePlayback} onInterruptAssistantTts={handleInterruptAssistantTts} onReplayAssistantTts={handleReplayAssistantTts} bubbleTextStyle={chatBubbleTextStyle} voiceDurationTextStyle={chatVoiceDurationStyle} readAloudIconSize={chatReadAloudIconSize} readAloudButtonSize={chatReadAloudButtonSize} + userAvatarUri={userAvatarUri} + userAvatarLetter={userAvatarLetter} + conversationId={id ?? ''} + assistantSegmentIndex={assistantBubbleSegmentIndex( + item, + item.listKey, + )} + durableAssistantId={durableAssistantIdForBubble(item, id ?? 
'')} + requestAssistantSegmentTts={requestAssistantSegmentTts} /> )} onContentSizeChange={() => @@ -1586,7 +1834,7 @@ export default function ConversationScreen() { agentName={t('agentName')} streamingTtsActive={ !!streamingMessage && - playerStatus === 'playing' && + (playerStatus === 'playing' || playerStatus === 'paused') && currentPlaybackItem?.kind === 'tts_auto' } onStreamingPress={handleInterruptAssistantTts} @@ -1692,6 +1940,12 @@ const styles = StyleSheet.create({ alignItems: 'center', gap: 10, }, + headerTtsRow: { + flexDirection: 'row', + alignItems: 'center', + gap: 6, + flexShrink: 0, + }, headerTitle: { fontWeight: '700', color: CHAT_COLORS.primary, @@ -1755,6 +2009,16 @@ const styles = StyleSheet.create({ backgroundColor: CHAT_COLORS.primaryFixed, borderColor: 'rgba(141, 140, 144, 0.35)', }, + userAvatarFallback: { + alignItems: 'center', + justifyContent: 'center', + backgroundColor: CHAT_COLORS.primaryFixed, + }, + userAvatarFallbackText: { + fontSize: 17, + fontWeight: '600', + color: CHAT_COLORS.onSurfaceVariant, + }, /** 远程图在 Android 上若用 100% 尺寸可能解析为 0×0,需写死数值(见 Expo 文档) */ avatarImage: { width: 40, diff --git a/app-expo/src/app/(main)/personal-info.tsx b/app-expo/src/app/(main)/personal-info.tsx index 2aa2866..342ef7d 100644 --- a/app-expo/src/app/(main)/personal-info.tsx +++ b/app-expo/src/app/(main)/personal-info.tsx @@ -1,24 +1,73 @@ +import { Image } from 'expo-image'; +import * as ImagePicker from 'expo-image-picker'; +import { router } from 'expo-router'; import React, { useEffect, useState } from 'react'; -import { View } from 'react-native'; -import { SafeAreaView } from 'react-native-safe-area-context'; +import { useTranslation } from 'react-i18next'; +import { + ActivityIndicator, + Alert, + Dimensions, + Keyboard, + KeyboardAvoidingView, + Modal, + Platform, + Pressable, + ScrollView, + View, +} from 'react-native'; +import { + SafeAreaView, + useSafeAreaInsets, +} from 'react-native-safe-area-context'; import { Button } from 
'@/components/ui/button'; import { Input } from '@/components/ui/input'; import { Text } from '@/components/ui/text'; import { ScreenHeader } from '@/components/screen-header'; +import { resolveApiMediaUrl } from '@/core/api/media-url'; +import { ApiError } from '@/core/api/types'; +import { buildAvatarUploadFormData } from '@/features/auth/avatar-upload-form-data'; +import { + useAvatarPresets, + useSetAvatarPreset, + useUpdateNickname, + useUploadAvatar, +} from '@/features/auth/hooks'; import { useProfile, useUpdateProfile } from '@/features/profile/hooks'; -export default function PersonalInfoScreen() { - const { data: profile } = useProfile(); - const update = useUpdateProfile(); +const PRESET_GRID_H_PADDING = 16 * 2; +const TILE_GAP = 12; +const COLS = 4; +function computePresetTileSize(): number { + const w = Dimensions.get('window').width; + return (w - PRESET_GRID_H_PADDING - TILE_GAP * (COLS - 1)) / COLS; +} + +type AvatarModalStep = 'menu' | 'presets'; + +export default function PersonalInfoScreen() { + const { t } = useTranslation('profile'); + const insets = useSafeAreaInsets(); + const { data: profile, isLoading: profileLoading } = useProfile(); + const update = useUpdateProfile(); + const updateNicknameMut = useUpdateNickname(); + const uploadAvatar = useUploadAvatar(); + const setPreset = useSetAvatarPreset(); + const { data: presets, isLoading: presetsLoading } = useAvatarPresets(); + + const [nickname, setNickname] = useState(''); const [birthYear, setBirthYear] = useState(''); const [birthPlace, setBirthPlace] = useState(''); const [grewUpPlace, setGrewUpPlace] = useState(''); const [occupation, setOccupation] = useState(''); + const [avatarModalOpen, setAvatarModalOpen] = useState(false); + const [avatarStep, setAvatarStep] = useState('menu'); + useEffect(() => { if (profile) { + setNickname(profile.nickname ?? ''); setBirthYear(profile.birth_year?.toString() ?? ''); setBirthPlace(profile.birth_place ?? 
''); setGrewUpPlace(profile.grew_up_place ?? ''); @@ -26,53 +75,306 @@ export default function PersonalInfoScreen() { } }, [profile]); - const handleSave = () => { - update.mutate({ - birth_year: birthYear ? Number(birthYear) : null, - birth_place: birthPlace || null, - grew_up_place: grewUpPlace || null, - occupation: occupation || null, - }); + const closeAvatarModal = () => { + setAvatarModalOpen(false); + setAvatarStep('menu'); }; + const pickFromLibrary = async () => { + const perm = await ImagePicker.requestMediaLibraryPermissionsAsync(); + if (!perm.granted) { + Alert.alert('', t('personalInfo.libraryPermissionDenied')); + return; + } + + const result = await ImagePicker.launchImageLibraryAsync({ + mediaTypes: ['images'], + allowsEditing: true, + aspect: [1, 1], + quality: 0.9, + }); + + if (result.canceled) return; + + const asset = result.assets[0]; + + try { + const form = await buildAvatarUploadFormData(asset); + await uploadAvatar.mutateAsync(form); + closeAvatarModal(); + } catch (err) { + const msg = + err instanceof ApiError + ? err.message + : err instanceof Error + ? err.message + : String(err); + Alert.alert(t('personalInfo.avatarUploadFailed'), msg); + } + }; + + const applyPreset = async (presetId: string) => { + try { + await setPreset.mutateAsync(presetId); + closeAvatarModal(); + } catch (err) { + const msg = + err instanceof ApiError + ? err.message + : err instanceof Error + ? err.message + : String(err); + Alert.alert(t('personalInfo.avatarPresetFailed'), msg); + } + }; + + const avatarBusy = uploadAvatar.isPending || setPreset.isPending; + const avatarUri = resolveApiMediaUrl(profile?.avatar_url ?? 
null); + const tileSize = computePresetTileSize(); + + const handleSave = async () => { + const trimmed = nickname.trim(); + if (!trimmed) { + Alert.alert('', t('personalInfo.nicknameRequired')); + return; + } + + let nicknameCommitted = false; + try { + if (profile && trimmed !== profile.nickname) { + await updateNicknameMut.mutateAsync({ nickname: trimmed }); + nicknameCommitted = true; + } + await update.mutateAsync({ + birth_year: birthYear ? Number(birthYear) : null, + birth_place: birthPlace || null, + grew_up_place: grewUpPlace || null, + occupation: occupation || null, + }); + Keyboard.dismiss(); + if (router.canGoBack()) { + router.back(); + } else { + router.replace('/(tabs)/profile'); + } + } catch (err) { + const msg = + err instanceof ApiError + ? err.message + : err instanceof Error + ? err.message + : String(err); + if (nicknameCommitted) { + Alert.alert( + t('personalInfo.savePartialTitle'), + `${t('personalInfo.savePartialBody')}\n\n${msg}`, + ); + } else { + Alert.alert(t('personalInfo.saveFailed'), msg); + } + } + }; + + const saving = update.isPending || updateNicknameMut.isPending; + const keyboardVerticalOffset = Platform.OS === 'ios' ? insets.top : 0; + return ( - - - - - - - - + + + + + + + { + setAvatarStep('menu'); + setAvatarModalOpen(true); + }} + > + + {avatarUri ? ( + + ) : ( + + + {nickname.trim().slice(0, 1).toUpperCase() || '?'} + + + )} + {avatarBusy ? ( + + + + ) : null} + + + + {t('personalInfo.changeAvatar')} + + - {update.error && ( - - {update.error.message} - - )} + + + {t('personalInfo.nickname')} + + + + + + + - - + {(update.error ?? updateNicknameMut.error) != null ? ( + + {(updateNicknameMut.error ?? update.error)?.message} + + ) : null} + + + + + + + + + + + {avatarStep === 'presets' ? ( + setAvatarStep('menu')} + > + {t('personalInfo.back')} + + ) : ( + + )} + + {avatarStep === 'presets' + ? 
t('personalInfo.presetPickTitle') + : t('personalInfo.changeAvatar')} + + + {t('personalInfo.cancel')} + + + + {avatarStep === 'menu' ? ( + + + + + ) : ( + + {presetsLoading ? ( + + ) : ( + + {(presets ?? []).map((item) => { + const uri = resolveApiMediaUrl(item.url); + return ( + void applyPreset(item.id)} + style={{ + width: tileSize, + height: tileSize, + }} + > + {uri ? ( + + ) : null} + + ); + })} + + )} + + )} + + ); } diff --git a/app-expo/src/app/(tabs)/index.tsx b/app-expo/src/app/(tabs)/index.tsx index ff28f30..ff5235b 100644 --- a/app-expo/src/app/(tabs)/index.tsx +++ b/app-expo/src/app/(tabs)/index.tsx @@ -22,6 +22,12 @@ import { Text } from '@/components/ui/text'; import { NetworkError } from '@/core/api/types'; import { useTypography } from '@/core/typography-context'; import { conversationApi } from '@/features/conversation/api'; +import { + prefetchConversationMessages, + prewarmConversationSession, + warmupConversationOpening, +} from '@/features/conversation/entry-warmup'; +import { abandonPreparedRealtimeSession } from '@/features/conversation/prepared-session-registry'; import { useConversations, useCreateConversation, @@ -88,9 +94,13 @@ function GreetingCardSkeleton() { function ConversationCard({ item, onPress, + onPressIn, + disabled, }: { item: ConversationListItem; onPress: () => void; + onPressIn?: () => void; + disabled?: boolean; }) { const { t } = useTranslation('conversation'); const typography = useTypography(); @@ -140,6 +150,8 @@ function ConversationCard({ {renderAvatar()} @@ -177,9 +189,13 @@ function ConversationCard({ function SwipeableConversationCard({ item, onPress, + onPressIn, + disabled, }: { item: ConversationListItem; onPress: () => void; + onPressIn?: () => void; + disabled?: boolean; }) { const { t } = useTranslation('conversation'); const deleteConversation = useDeleteConversation(); @@ -215,7 +231,12 @@ function SwipeableConversationCard({ return ( - + ); } @@ -237,26 +258,34 @@ function conversationStartedAtMs(item: 
ConversationListItem): number { return item.startedAt ?? item.latestMessageTime; } +function conversationHasAnyMessage(item: ConversationListItem): boolean { + const preview = + typeof item.latestMessagePreview === 'string' + ? item.latestMessagePreview + : ''; + return item.hasUserMessage || preview.trim().length > 0; +} + function msUntilNextLocalMidnight(nowMs: number): number { const next = new Date(nowMs); next.setHours(24, 0, 0, 0); return Math.max(1, next.getTime() - nowMs); } -/** 仅复用「当天创建」且尚无用户消息的对话,跨日则新开(一天一次招呼会话) */ +/** 仅复用「当天创建」且完全空白的对话;AI 已开场后应进入「继续对话」 */ function findReusableEmptyConversationId( items: ConversationListItem[], nowMs: number = Date.now(), ): string | null { const found = items.find( (c) => - c.hasUserMessage === false && + !conversationHasAnyMessage(c) && isSameLocalCalendarDay(conversationStartedAtMs(c), nowMs), ); return found?.id ?? null; } -/** 「继续对话」:仅进入当天创建且已有用户消息的对话(列表已按最近活动排序) */ +/** 「继续对话」:当天会话只要已有 AI/用户任意消息即可继续(列表已按最近活动排序) */ function findTodayConversationToResume( items: ConversationListItem[], nowMs: number = Date.now(), @@ -264,7 +293,7 @@ function findTodayConversationToResume( return ( items.find( (c) => - c.hasUserMessage && + conversationHasAnyMessage(c) && isSameLocalCalendarDay(conversationStartedAtMs(c), nowMs), ) ?? 
null ); @@ -277,6 +306,7 @@ export default function ConversationsScreen() { const { data: conversations = [], isLoading } = useConversations(); const createConversation = useCreateConversation(); const createOnceGuardRef = useRef(false); + const [isEnteringChat, setIsEnteringChat] = useState(false); const [nowMs, setNowMs] = useState(() => Date.now()); const isEmpty = conversations.length === 0; @@ -311,8 +341,47 @@ export default function ConversationsScreen() { }; }, []); + const navigateToConversation = async ( + conversationId: string, + needsOpeningWarmup: boolean, + ) => { + if (!needsOpeningWarmup) { + // 已有消息:立即跳转,列表/按下时已预热 WS 与缓存,避免阻塞用户感知 + prewarmConversationSession(queryClient, conversationId); + void prefetchConversationMessages(queryClient, conversationId); + router.push(`/(main)/conversation/${conversationId}`); + return; + } + setIsEnteringChat(true); + try { + await warmupConversationOpening(queryClient, conversationId); + router.push(`/(main)/conversation/${conversationId}`); + } catch (err) { + abandonPreparedRealtimeSession(conversationId); + const msg = + err instanceof NetworkError + ? t('createError') + : err instanceof Error + ? err.message + : t('createError'); + Alert.alert(t('chatTitle'), msg, [{ text: t('confirm') }]); + } finally { + setIsEnteringChat(false); + } + }; + + /** 用户按下卡片即开始预热(WS + 消息缓存),抬起手指时大概率已就绪 */ + const handleConversationPressIn = (item: ConversationListItem) => { + if (!conversationHasAnyMessage(item)) return; + prewarmConversationSession(queryClient, item.id); + }; + const handleCreateConversation = () => { - if (createConversation.isPending || createOnceGuardRef.current) { + if ( + createConversation.isPending || + createOnceGuardRef.current || + isEnteringChat + ) { return; } createOnceGuardRef.current = true; @@ -325,8 +394,11 @@ export default function ConversationsScreen() { }); const reuseId = findReusableEmptyConversationId(fresh ?? 
[]); if (reuseId) { - createOnceGuardRef.current = false; - router.push(`/(main)/conversation/${reuseId}`); + try { + await navigateToConversation(reuseId, true); + } finally { + createOnceGuardRef.current = false; + } return; } } catch { @@ -334,9 +406,12 @@ export default function ConversationsScreen() { } createConversation.mutate(undefined, { - onSuccess: (result) => { - createOnceGuardRef.current = false; - router.push(`/(main)/conversation/${result.id}`); + onSuccess: async (result) => { + try { + await navigateToConversation(result.id, true); + } finally { + createOnceGuardRef.current = false; + } }, onError: (err) => { createOnceGuardRef.current = false; @@ -357,17 +432,29 @@ export default function ConversationsScreen() { const handleResumeLatestConversation = () => { const toResume = findTodayConversationToResume(conversations, nowMs); if (toResume) { - router.push(`/(main)/conversation/${toResume.id}`); + void navigateToConversation(toResume.id, false); return; } // 当日没有可继续的会话(例如会话始于昨日):与「打个招呼」一致,复用当日空会话或新建 handleCreateConversation(); }; - const handleConversationPress = (id: string) => { - router.push(`/(main)/conversation/${id}`); + const handleConversationPress = (item: ConversationListItem) => { + void navigateToConversation(item.id, !conversationHasAnyMessage(item)); }; + /** + * 列表加载完成后,预热"今天可继续"的会话(用户最可能点)的 WS。 + * 单槽连接池:换会话会自动 dispose 旧槽,所以这里只挑一条最像即将被点的。 + */ + useEffect(() => { + if (isLoading) return; + const candidate = + todayConversation ?? 
conversations.find(conversationHasAnyMessage); + if (!candidate) return; + prewarmConversationSession(queryClient, candidate.id); + }, [isLoading, conversations, todayConversation, queryClient]); + return ( @@ -401,7 +488,7 @@ export default function ConversationsScreen() { @@ -424,6 +511,7 @@ export default function ConversationsScreen() { handleConversationPress(item.id)} + disabled={isEnteringChat} + onPress={() => handleConversationPress(item)} + onPressIn={() => handleConversationPressIn(item)} /> ))} diff --git a/app-expo/src/app/(tabs)/profile.tsx b/app-expo/src/app/(tabs)/profile.tsx index 9922d82..cef7346 100644 --- a/app-expo/src/app/(tabs)/profile.tsx +++ b/app-expo/src/app/(tabs)/profile.tsx @@ -1,4 +1,5 @@ import { router } from 'expo-router'; +import { Image } from 'expo-image'; import React from 'react'; import { Pressable, ScrollView, View } from 'react-native'; import { useTranslation } from 'react-i18next'; @@ -21,6 +22,7 @@ import { import { Icon } from '@/components/ui/icon'; import { Switch } from '@/components/ui/switch'; import { Text } from '@/components/ui/text'; +import { resolveApiMediaUrl } from '@/core/api/media-url'; import { useAppSettings } from '@/hooks/use-app-settings'; import { useSession, useLogout } from '@/features/auth/hooks'; import { useCurrentPlan } from '@/features/profile/hooks'; @@ -181,6 +183,8 @@ export default function ProfileScreen() { themeOptions.find((o) => o.value === themeName)?.label ?? tApp('theme.default'); + const avatarUri = resolveApiMediaUrl(user?.avatar_url ?? null); + return ( - + {avatarUri ? 
( + + ) : ( + + )} svg]:px-3' })), - sm: cn('h-9 gap-1.5 rounded-md px-3 sm:h-8', Platform.select({ web: 'has-[>svg]:px-2.5' })), - lg: cn('h-11 rounded-md px-6 sm:h-10', Platform.select({ web: 'has-[>svg]:px-4' })), + default: cn('min-h-10 px-4 py-2 sm:min-h-9', Platform.select({ web: 'has-[>svg]:px-3' })), + sm: cn('min-h-9 gap-1.5 rounded-md px-3 sm:min-h-8', Platform.select({ web: 'has-[>svg]:px-2.5' })), + lg: cn('min-h-11 rounded-md px-6 sm:min-h-10', Platform.select({ web: 'has-[>svg]:px-4' })), icon: 'h-10 w-10 sm:h-9 sm:w-9', }, }, diff --git a/app-expo/src/core/api/media-url.ts b/app-expo/src/core/api/media-url.ts new file mode 100644 index 0000000..06982be --- /dev/null +++ b/app-expo/src/core/api/media-url.ts @@ -0,0 +1,14 @@ +import { config } from '@/core/config'; + +/** 将 API 返回的相对路径(如 `/api/auth/avatars/x.jpg`)转为可请求的绝对 URL。 */ +export function resolveApiMediaUrl( + pathOrUrl: string | null | undefined, +): string | null { + if (pathOrUrl == null || pathOrUrl === '') return null; + if (/^https?:\/\//i.test(pathOrUrl)) return pathOrUrl; + if (pathOrUrl.startsWith('/')) { + const base = config.apiBaseUrl.replace(/\/$/, ''); + return `${base}${pathOrUrl}`; + } + return pathOrUrl; +} diff --git a/app-expo/src/core/ws/client.ts b/app-expo/src/core/ws/client.ts index a4702cd..bd7cedd 100644 --- a/app-expo/src/core/ws/client.ts +++ b/app-expo/src/core/ws/client.ts @@ -11,6 +11,13 @@ import type { export type WsEventListener = (event: WsEvent) => void; export type WsStateListener = (state: WsConnectionState) => void; +function buildWsUrl(conversationId: string, token: string): string { + const baseUrl = config.wsBaseUrl.replace(/\/+$/u, ''); + const encodedConversationId = encodeURIComponent(conversationId); + const encodedToken = encodeURIComponent(token); + return `${baseUrl}/ws/conversation/${encodedConversationId}?token=${encodedToken}`; +} + function mapServerMessage(raw: RawServerMessage): WsEvent | null { const cid = raw.conversation_id; const d = 
raw.data; @@ -51,6 +58,7 @@ function mapServerMessage(raw: RawServerMessage): WsEvent | null { index: d.index as number | undefined, total: d.total as number | undefined, assistantMessageId: d.assistant_message_id as string | undefined, + manual: d.manual as boolean | undefined, }; case 'end_conversation': @@ -124,7 +132,7 @@ export class WsClient { return; } - const url = `${config.wsBaseUrl}/ws/conversation/${this.conversationId}?token=${token}`; + const url = buildWsUrl(this.conversationId, token); try { this.ws = new WebSocket(url); @@ -186,14 +194,37 @@ export class WsClient { return true; } - sendText(text: string): boolean { - return this.send({ type: 'text', data: { text } }); + sendText(text: string, opts?: { ttsThisTurn?: boolean }): boolean { + return this.send({ + type: 'text', + data: { + text, + ...(opts?.ttsThisTurn === true ? { tts_this_turn: true } : {}), + }, + }); } sendTtsCancel(): boolean { return this.send({ type: 'tts_cancel', data: {} }); } + sendTtsRequest(body: { + assistantMessageId: string; + segmentIndex: number; + segmentText?: string; + }): boolean { + return this.send({ + type: 'tts_request', + data: { + assistant_message_id: body.assistantMessageId, + segment_index: body.segmentIndex, + ...(body.segmentText != null && body.segmentText !== '' + ? 
{ segment_text: body.segmentText } + : {}), + }, + }); + } + sendEndConversation(): boolean { return this.send({ type: 'end_conversation', data: {} }); } diff --git a/app-expo/src/core/ws/types.ts b/app-expo/src/core/ws/types.ts index 905759f..955747d 100644 --- a/app-expo/src/core/ws/types.ts +++ b/app-expo/src/core/ws/types.ts @@ -16,6 +16,7 @@ export type ClientMessageType = | 'audio_message' | 'transcribe_only' | 'tts_cancel' + | 'tts_request' | 'end_conversation'; export interface RawServerMessage { @@ -69,6 +70,8 @@ export interface TtsAudioReceivedEvent { total?: number; /** 持久化后的助手消息 id(与 REST `messages` 中 `id` 对齐) */ assistantMessageId?: string; + /** 用户点击喇叭按需请求时为 true,客户端应播放(与「仅朗读开关打开这一轮」的自动 TTS 区分) */ + manual?: boolean; } export interface ConversationEndedEvent { diff --git a/app-expo/src/features/auth/api.ts b/app-expo/src/features/auth/api.ts index f89ffc0..7b178cb 100644 --- a/app-expo/src/features/auth/api.ts +++ b/app-expo/src/features/auth/api.ts @@ -1,11 +1,13 @@ import { api } from '@/core/api/client'; import type { + AvatarPresetItem, ChangePasswordRequest, ChangePhoneRequest, LoginRequest, RegisterRequest, ResetPasswordRequest, + SetAvatarPresetRequest, SmsLoginRequest, SmsRegisterRequest, SmsRequest, @@ -90,4 +92,14 @@ export const authApi = { uploadAvatar(file: FormData) { return api.post(`${AUTH}/me/avatar`, { body: file }); }, + + fetchAvatarPresets() { + return api.get(`${AUTH}/avatar-presets`, { + skipAuth: true, + }); + }, + + setAvatarPreset(body: SetAvatarPresetRequest) { + return api.put(`${AUTH}/me/avatar/preset`, { body }); + }, } as const; diff --git a/app-expo/src/features/auth/avatar-upload-form-data.ts b/app-expo/src/features/auth/avatar-upload-form-data.ts new file mode 100644 index 0000000..e72d145 --- /dev/null +++ b/app-expo/src/features/auth/avatar-upload-form-data.ts @@ -0,0 +1,77 @@ +import type * as ImagePicker from 'expo-image-picker'; +import { Platform } from 'react-native'; + +type AvatarMime = 'image/jpeg' | 
'image/png' | 'image/webp'; + +function inferMimeFromUri(uri: string): AvatarMime { + const u = uri.toLowerCase(); + if (u.endsWith('.png')) return 'image/png'; + if (u.endsWith('.webp')) return 'image/webp'; + return 'image/jpeg'; +} + +function coerceMime(value: string | null | undefined, uri: string): AvatarMime { + if ( + value === 'image/jpeg' || + value === 'image/png' || + value === 'image/webp' + ) { + return value; + } + return inferMimeFromUri(uri); +} + +function mimeToFilename(mime: AvatarMime): string { + switch (mime) { + case 'image/png': + return 'avatar.png'; + case 'image/webp': + return 'avatar.webp'; + default: + return 'avatar.jpg'; + } +} + +/** + * 构建与后端 `POST /api/auth/me/avatar` 约定的 multipart(字段名 `file`)。 + * Native:`{ uri, name, type }`;Web:`File`,避免 RN FormData 在 Web 上不识别 `uri`。 + */ +export async function buildAvatarUploadFormData( + asset: ImagePicker.ImagePickerAsset, +): Promise { + const uri = asset.uri; + const mime = coerceMime(asset.mimeType, uri); + const filename = mimeToFilename(mime); + const form = new FormData(); + + if (Platform.OS === 'web') { + const webFile = asset.file; + if ( + webFile instanceof File && + (webFile.type === 'image/jpeg' || + webFile.type === 'image/png' || + webFile.type === 'image/webp') + ) { + form.append( + 'file', + webFile, + webFile.name || mimeToFilename(coerceMime(webFile.type, uri)), + ); + return form; + } + + const res = await fetch(uri); + const blob = await res.blob(); + const type = coerceMime(blob.type, uri); + form.append('file', new File([blob], mimeToFilename(type), { type })); + return form; + } + + form.append('file', { + uri, + name: filename, + type: mime, + } as unknown as Blob); + + return form; +} diff --git a/app-expo/src/features/auth/hooks.ts b/app-expo/src/features/auth/hooks.ts index d4e3ee3..6bc217a 100644 --- a/app-expo/src/features/auth/hooks.ts +++ b/app-expo/src/features/auth/hooks.ts @@ -4,6 +4,7 @@ import { useCallback } from 'react'; import { AuthError } from 
'@/core/api/types'; import { tokenManager } from '@/core/auth/token-manager'; +import { disposeAllBackgroundConversationWs } from '@/features/conversation/conversation-ws-background-pool'; import { authApi } from './api'; import type { @@ -14,6 +15,7 @@ import type { SmsRegisterRequest, SmsRequest, TokenResponse, + UpdateNicknameRequest, UserInfo, } from './types'; @@ -24,6 +26,16 @@ export const authKeys = { tokenCheck: ['auth', 'token-check'] as const, }; +const PROFILE_QUERY_PREFIX = ['profile'] as const; + +function syncSessionAndProfileQueries( + queryClient: ReturnType, + user: UserInfo, +) { + queryClient.setQueryData(authKeys.session, user); + queryClient.invalidateQueries({ queryKey: PROFILE_QUERY_PREFIX }); +} + // ─── useSession ─── /** @@ -162,6 +174,44 @@ export function useSmsCode() { }); } +// ─── Avatar / nickname ─── + +export function useAvatarPresets() { + return useQuery({ + queryKey: ['avatar-presets'], + queryFn: () => authApi.fetchAvatarPresets(), + staleTime: 60 * 60 * 1000, + }); +} + +export function useUpdateNickname() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: (body: UpdateNicknameRequest) => authApi.updateNickname(body), + onSuccess: (user) => syncSessionAndProfileQueries(queryClient, user), + }); +} + +export function useUploadAvatar() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: (form: FormData) => authApi.uploadAvatar(form), + onSuccess: (user) => syncSessionAndProfileQueries(queryClient, user), + }); +} + +export function useSetAvatarPreset() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: (presetId: string) => + authApi.setAvatarPreset({ preset_id: presetId }), + onSuccess: (user) => syncSessionAndProfileQueries(queryClient, user), + }); +} + // ─── useLogout ─── /** @@ -185,6 +235,7 @@ export function useLogout() { } }, onSettled: async () => { + disposeAllBackgroundConversationWs(); await tokenManager.clearTokens(); 
queryClient.clear(); queryClient.setQueryData(authKeys.tokenCheck, false); diff --git a/app-expo/src/features/auth/types.ts b/app-expo/src/features/auth/types.ts index 1f10285..c1db232 100644 --- a/app-expo/src/features/auth/types.ts +++ b/app-expo/src/features/auth/types.ts @@ -79,6 +79,15 @@ export interface UpdateNicknameRequest { nickname: string; } +export interface AvatarPresetItem { + id: string; + url: string; +} + +export interface SetAvatarPresetRequest { + preset_id: string; +} + // ─── Session state ─── export type SessionStatus = diff --git a/app-expo/src/features/conversation/conversation-ws-background-pool.ts b/app-expo/src/features/conversation/conversation-ws-background-pool.ts new file mode 100644 index 0000000..858b680 --- /dev/null +++ b/app-expo/src/features/conversation/conversation-ws-background-pool.ts @@ -0,0 +1,88 @@ +import type { QueryClient } from '@tanstack/react-query'; +import { AppState, type AppStateStatus } from 'react-native'; + +import { RealtimeSession } from './realtime-session'; + +type Slot = { conversationId: string; session: RealtimeSession }; + +let slot: Slot | null = null; + +/** 与常见聊天 App 一致:仅当应用进入 background 时断开长连(避免后台挂 socket);inactive 不处理以减少控制中心等短暂打断 */ +let backgroundUnsubscribe: (() => void) | null = null; + +function installBackgroundLifecycleOnce(): void { + if (backgroundUnsubscribe) return; + const sub = AppState.addEventListener('change', (next: AppStateStatus) => { + if (next === 'background') { + disposeAllBackgroundConversationWs(); + } + }); + backgroundUnsubscribe = () => sub.remove(); +} + +function disposeSlot(): void { + if (!slot) return; + slot.session.dispose(); + slot = null; +} + +const offScreenUi = { + onStreamingText: () => {}, + onTtsSegment: () => {}, + onError: () => {}, + onStateChange: () => {}, +}; + +/** 离屏:保持 WebSocket,去掉 UI 回调,避免列表页播 TTS 或对已卸载组件 setState */ +export function releaseConversationWsUi(session: RealtimeSession): void { + session.attachUiCallbacks({ + onStreamingText: 
offScreenUi.onStreamingText, + onTtsSegment: offScreenUi.onTtsSegment, + onError: offScreenUi.onError, + onStateChange: offScreenUi.onStateChange, + }); +} + +/** 删除会话等场景:关闭对应长连 */ +export function disposeBackgroundConversationWs(conversationId: string): void { + if (slot?.conversationId === conversationId) { + disposeSlot(); + } +} + +/** 登出 / 清账号:关闭池中连接 */ +export function disposeAllBackgroundConversationWs(): void { + disposeSlot(); +} + +/** + * 单槽:仅保留「最近进入」的一个会话长连。换会话会dispose旧槽;同会话返回池中实例。 + */ +export function acquireBackgroundConversationWs( + conversationId: string, + queryClient: QueryClient, + prepared: RealtimeSession | null, +): RealtimeSession { + installBackgroundLifecycleOnce(); + if (prepared) { + if ( + slot && + (slot.conversationId !== conversationId || slot.session !== prepared) + ) { + disposeSlot(); + } + slot = { conversationId, session: prepared }; + return prepared; + } + + if (slot?.conversationId === conversationId) { + void slot.session.connect(); + return slot.session; + } + + disposeSlot(); + const session = new RealtimeSession({ conversationId, queryClient }); + slot = { conversationId, session }; + void session.connect(); + return session; +} diff --git a/app-expo/src/features/conversation/entry-warmup.ts b/app-expo/src/features/conversation/entry-warmup.ts new file mode 100644 index 0000000..579b38f --- /dev/null +++ b/app-expo/src/features/conversation/entry-warmup.ts @@ -0,0 +1,147 @@ +import type { QueryClient } from '@tanstack/react-query'; + +import { acquireBackgroundConversationWs } from './conversation-ws-background-pool'; +import { conversationMessagesRepository } from './conversation-messages-repository'; +import { conversationKeys } from './query-keys'; +import { registerPreparedRealtimeSession } from './prepared-session-registry'; +import { RealtimeSession } from './realtime-session'; +import type { MessageItem } from './types'; + +const OPENING_WARMUP_TIMEOUT_MS = 50_000; +const CACHE_POLL_MS = 120; + +function 
cacheHasAssistantMessage( + queryClient: QueryClient, + conversationId: string, +): boolean { + const data = queryClient.getQueryData( + conversationKeys.messages(conversationId), + ); + return (data ?? []).some((m) => m.senderType === 'assistant'); +} + +function waitForAssistantInCache( + queryClient: QueryClient, + conversationId: string, + timeoutMs: number, +): Promise { + return new Promise((resolve) => { + const deadline = Date.now() + timeoutMs; + const schedule = () => { + if (cacheHasAssistantMessage(queryClient, conversationId)) { + resolve(true); + return; + } + if (Date.now() >= deadline) { + resolve(false); + return; + } + setTimeout(schedule, CACHE_POLL_MS); + }; + schedule(); + }); +} + +export async function prefetchConversationMessages( + queryClient: QueryClient, + conversationId: string, +): Promise { + await queryClient.prefetchQuery({ + queryKey: conversationKeys.messages(conversationId), + queryFn: () => conversationMessagesRepository.loadMessages(conversationId), + }); +} + +const offscreenUiCallbacks = { + onStreamingText: () => {}, + onTtsSegment: () => {}, + onError: () => {}, + onStateChange: () => {}, +}; + +const inflightPrewarms = new Set(); + +/** + * 列表页/卡片按下时的预热:保持后台 WS 连接,并触发消息缓存填充。 + * 与 `warmupConversationOpening` 不同:不等待开场白、不阻塞调用方,仅适用于"已有消息"的会话。 + */ +export function prewarmConversationSession( + queryClient: QueryClient, + conversationId: string, +): void { + if (!conversationId) return; + const session = acquireBackgroundConversationWs( + conversationId, + queryClient, + null, + ); + // 预热阶段没有挂载的 UI,先用空回调占位;聊天页 mount 时会重新 attach。 + session.attachUiCallbacks(offscreenUiCallbacks); + if (inflightPrewarms.has(conversationId)) return; + const cached = queryClient.getQueryData( + conversationKeys.messages(conversationId), + ); + // 已有缓存就交给 React Query staleTime 决定是否刷新;只对首次进入做后台预取 + if (cached && cached.length > 0) return; + inflightPrewarms.add(conversationId); + void prefetchConversationMessages(queryClient, 
conversationId).finally(() => {
+    inflightPrewarms.delete(conversationId);
+  });
+}
+
+/**
+ * Forces a REST reload of the conversation's messages into the React Query
+ * cache (`staleTime: 0`). Unlike a prefetch, `fetchQuery` rejects when the
+ * load fails, so callers see the error.
+ */
+async function refreshConversationMessagesForWarmup(
+  queryClient: QueryClient,
+  conversationId: string,
+): Promise<void> {
+  await queryClient.fetchQuery({
+    queryKey: conversationKeys.messages(conversationId),
+    queryFn: () => conversationMessagesRepository.loadMessages(conversationId),
+    staleTime: 0,
+  });
+}
+
+/**
+ * While still on the conversation list, connects the WS and waits for the
+ * first assistant opening line to be written into React Query; on success the
+ * session is parked so the chat page can take it over.
+ *
+ * On timeout or failure the session is disposed and the chat page reconnects
+ * on its own (the server does not repeat the opening if it was already written
+ * to history).
+ *
+ * @param queryClient React Query client whose message cache is inspected.
+ * @param conversationId Conversation to warm up.
+ * @returns Resolves when the warm-up attempt is over, whether or not a session
+ *   was parked. Rejects if the history refresh or `session.connect()` rejects;
+ *   in the latter case the session has already been disposed.
+ */
+export async function warmupConversationOpening(
+  queryClient: QueryClient,
+  conversationId: string,
+): Promise<void> {
+  // Start with a REST history refresh: if the access token has expired, the
+  // API client refreshes it here; it also avoids waiting on the WS when
+  // Redis/DB already has the opening line but the local cache is still empty.
+  await refreshConversationMessagesForWarmup(queryClient, conversationId);
+
+  if (cacheHasAssistantMessage(queryClient, conversationId)) {
+    return;
+  }
+
+  const session = new RealtimeSession({ conversationId, queryClient });
+
+  // No UI is mounted yet, so use the file's shared no-op callbacks.
+  session.attachUiCallbacks(offscreenUiCallbacks);
+
+  // Ownership flag: once the registry holds the session, this function must
+  // not dispose it. Every other exit path (timeout, connect() rejecting, any
+  // unexpected throw) disposes it so no orphaned WebSocket is left behind.
+  let handedOff = false;
+  try {
+    await session.connect();
+
+    const ok = await waitForAssistantInCache(
+      queryClient,
+      conversationId,
+      OPENING_WARMUP_TIMEOUT_MS,
+    );
+
+    if (ok) {
+      registerPreparedRealtimeSession(conversationId, session);
+      handedOff = true;
+      await prefetchConversationMessages(queryClient, conversationId);
+    }
+  } finally {
+    if (!handedOff) {
+      session.dispose();
+    }
+  }
+}
diff --git a/app-expo/src/features/conversation/hooks.ts b/app-expo/src/features/conversation/hooks.ts
index ae44534..239c465 100644
--- a/app-expo/src/features/conversation/hooks.ts
+++ b/app-expo/src/features/conversation/hooks.ts
@@ -1,17 +1,25 @@
 import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
 import { File, Paths } from 'expo-file-system';
 import { useCallback, useEffect, useRef, useState } from 'react';
+import { AppState, type AppStateStatus } from 'react-native';
 import type { TopicSuggestion,
WsConnectionState } from '@/core/ws/types'; import { conversationApi } from './api'; +import { + acquireBackgroundConversationWs, + disposeAllBackgroundConversationWs, + disposeBackgroundConversationWs, + releaseConversationWsUi, +} from './conversation-ws-background-pool'; import { conversationMessagesRepository } from './conversation-messages-repository'; import { conversationKeys } from './query-keys'; +import { takePreparedRealtimeSession } from './prepared-session-registry'; import { - RealtimeSession, type ErrorCallback, type StreamingTextCallback, type TtsSegmentPayload, + type RealtimeSession, } from './realtime-session'; import { type ConversationListItem, @@ -126,6 +134,7 @@ export function useDeleteConversation() { mutationFn: (conversationId: string) => conversationApi.delete(conversationId), onSuccess: async (_, conversationId) => { + disposeBackgroundConversationWs(conversationId); await voiceSegmentStore.clearConversation(conversationId); queryClient.setQueryData( conversationKeys.lists(), @@ -165,6 +174,8 @@ interface UseRealtimeSessionOptions { onTtsSegment?: (payload: TtsSegmentPayload) => void; /** 用户发出下一条文本/语音成功后调用,用于恢复接受 TTS 片段(打断后丢弃迟到片段) */ onTtsPlaybackResume?: () => void; + /** 本条发送是否请求了「本轮助手朗读」,用于仅在该轮自动播放 WS TTS */ + onUserSendTtsPreference?: (requestedTts: boolean) => void; } const MIN_RECORDING_DURATION_SEC = 1; @@ -190,6 +201,11 @@ interface RealtimeSessionState { sendVoiceMessage: (uri: string, durationMs: number) => Promise; sendEndConversation: () => void; sendTtsCancel: () => void; + requestAssistantSegmentTts: (body: { + assistantMessageId: string; + segmentIndex: number; + segmentText?: string; + }) => boolean; } export function useRealtimeSession({ @@ -197,9 +213,17 @@ export function useRealtimeSession({ enabled = true, onTtsSegment, onTtsPlaybackResume, + onUserSendTtsPreference, }: UseRealtimeSessionOptions): RealtimeSessionState { const queryClient = useQueryClient(); const sessionRef = useRef(null); + const uiRef = useRef({ + 
handleStreamingText: (() => {}) as StreamingTextCallback, + handleError: (() => {}) as ErrorCallback, + onTtsSegment: undefined as + | ((payload: TtsSegmentPayload) => void) + | undefined, + }); const [connectionState, setConnectionState] = useState('disconnected'); @@ -211,6 +235,10 @@ export function useRealtimeSession({ [], ); + const [foregroundResumeGeneration, setForegroundResumeGeneration] = + useState(0); + const needsResumeAfterBackgroundRef = useRef(false); + const handleStreamingText: StreamingTextCallback = useCallback( (text, isComplete) => { if (text.trim().length > 0) { @@ -245,7 +273,8 @@ export function useRealtimeSession({ useEffect(() => { if (!enabled || !conversationId) return; - const session = new RealtimeSession({ + const prepared = takePreparedRealtimeSession(conversationId); + const session = acquireBackgroundConversationWs( conversationId, queryClient, onStreamingText: handleStreamingText, @@ -256,10 +285,10 @@ export function useRealtimeSession({ }); sessionRef.current = session; - session.connect(); + setConnectionState(session.getConnectionState()); return () => { - session.dispose(); + releaseConversationWsUi(session); sessionRef.current = null; setConnectionState('disconnected'); setStreamingMessage(null); @@ -277,15 +306,17 @@ export function useRealtimeSession({ ]); const sendText = useCallback( - (text: string) => { + (text: string, options?: { ttsThisTurn?: boolean }) => { if (!sessionRef.current) return; - const sent = sessionRef.current.sendText(text); + const sent = sessionRef.current.sendText(text, options); if (!sent) { setError('消息发送失败,连接未就绪'); return; } + onUserSendTtsPreference?.(options?.ttsThisTurn === true); + setAwaitingAssistantReply(true); setTopicSuggestions([]); onTtsPlaybackResume?.(); @@ -319,11 +350,15 @@ export function useRealtimeSession({ }, ); }, - [conversationId, queryClient, onTtsPlaybackResume], + [conversationId, queryClient, onTtsPlaybackResume, onUserSendTtsPreference], ); const sendVoiceMessage = 
useCallback( - async (uri: string, durationMs: number): Promise => { + async ( + uri: string, + durationMs: number, + options?: { ttsThisTurn?: boolean }, + ): Promise => { const session = sessionRef.current; if (!session) return false; @@ -340,12 +375,15 @@ export function useRealtimeSession({ clientSegmentId: `${voiceSessionId}-0`, isLast: true, duration: durationSec, + ttsThisTurn: options?.ttsThisTurn, }); if (!sent) { setError('语音发送失败,连接未就绪'); return false; } + onUserSendTtsPreference?.(options?.ttsThisTurn === true); + setAwaitingAssistantReply(true); setTopicSuggestions([]); const localId = `pending_voice_${Date.now()}`; @@ -391,7 +429,7 @@ export function useRealtimeSession({ return false; } }, - [conversationId, queryClient, onTtsPlaybackResume], + [conversationId, queryClient, onTtsPlaybackResume, onUserSendTtsPreference], ); const sendEndConversation = useCallback(() => { @@ -402,6 +440,15 @@ export function useRealtimeSession({ sessionRef.current?.sendTtsCancel(); }, []); + const requestAssistantSegmentTts = useCallback( + (body: { + assistantMessageId: string; + segmentIndex: number; + segmentText?: string; + }) => sessionRef.current?.requestAssistantSegmentTts(body) ?? 
false,
+    [],
+  );
+
   return {
     connectionState,
     streamingMessage,
@@ -413,5 +460,6 @@ export function useRealtimeSession({
     sendVoiceMessage,
     sendEndConversation,
     sendTtsCancel,
+    requestAssistantSegmentTts,
   };
 }
diff --git a/app-expo/src/features/conversation/prepared-session-registry.ts b/app-expo/src/features/conversation/prepared-session-registry.ts
new file mode 100644
index 0000000..dca09b8
--- /dev/null
+++ b/app-expo/src/features/conversation/prepared-session-registry.ts
@@ -0,0 +1,33 @@
+import type { RealtimeSession } from './realtime-session';
+
+const preparedByConversationId = new Map();
+
+/**
+ * Parks a session that finished warming up on the list page so the chat page
+ * can take it over when it mounts. If a different session is already parked
+ * for the same conversation, that one is disposed before being replaced.
+ */
+export function registerPreparedRealtimeSession(
+  conversationId: string,
+  session: RealtimeSession,
+): void {
+  const parked = preparedByConversationId.get(conversationId);
+  if (parked && parked !== session) parked.dispose();
+  preparedByConversationId.set(conversationId, session);
+}
+
+/**
+ * Removes and returns the parked session for `conversationId`; taking it
+ * counts as consuming it. Returns `null` when nothing is parked.
+ */
+export function takePreparedRealtimeSession(
+  conversationId: string,
+): RealtimeSession | null {
+  const parked = preparedByConversationId.get(conversationId);
+  if (parked) {
+    preparedByConversationId.delete(conversationId);
+    return parked;
+  }
+  return null;
+}
+
+/**
+ * Releases the parked connection when warm-up succeeded but navigation then
+ * failed, so no zombie WebSocket is left behind. No-op when nothing is parked.
+ */
+export function abandonPreparedRealtimeSession(conversationId: string): void {
+  // Unregister first, then dispose whatever was parked (if anything).
+  takePreparedRealtimeSession(conversationId)?.dispose();
+}
diff --git a/app-expo/src/features/conversation/realtime-session.ts b/app-expo/src/features/conversation/realtime-session.ts
index 63f4629..8c625e1 100644
--- a/app-expo/src/features/conversation/realtime-session.ts
+++ b/app-expo/src/features/conversation/realtime-session.ts
@@ -16,6 +16,13 @@ import { assistantSegmentMessageId, lastSegmentPreview } from './message-split';
 import { conversationKeys } from './query-keys';
 import type { ConversationListItem, MessageItem } from
'./types'; +/** 与落库助手消息 id、会话页 `durableAssistantIdForBubble` 的 uuid 判断一致 */ +function looksLikeUuidAssistantMessageId(id: string): boolean { + return /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test( + id, + ); +} + export type StreamingTextCallback = (text: string, isComplete: boolean) => void; export type ErrorCallback = (message: string, code?: string) => void; export type TopicSuggestionsCallback = (payload: { @@ -32,6 +39,8 @@ export type TtsSegmentPayload = { total?: number; /** 服务端持久化后的助手消息 id,用于与气泡 listKey / 消息 id 对齐 */ assistantMessageId?: string; + /** 用户点喇叭按需下发时为 true,应加入播放队列(即使未开「本轮朗读」) */ + manual?: boolean; }; interface RealtimeSessionOptions { @@ -65,12 +74,25 @@ export class RealtimeSession { private onTtsSegment?: (payload: TtsSegmentPayload) => void; private onTopicSuggestions?: TopicSuggestionsCallback; private onError?: ErrorCallback; + private uiStateListener?: WsStateListener; private unsubEvent: (() => void) | null = null; private unsubState: (() => void) | null = null; private streamingBuffer = ''; /** 单段回复且服务端带 `assistant_message_id` 时用于落缓存 id */ private pendingAssistantMessageId: string | null = null; + private destroyed = false; + + /** 本条用户消息是否请求「先 TTS 再出字」的助手轮次 */ + private assistantTurnTtsSync = false; + private pendingTtsByKey = new Map(); + + private static bufferedTtsKey( + assistantMessageId: string | undefined, + index: number, + ): string { + return `${assistantMessageId ?? 
'_'}:${index}`; + } constructor(options: RealtimeSessionOptions) { this.client = new WsClient(options.conversationId); @@ -80,11 +102,38 @@ export class RealtimeSession { this.onTtsSegment = options.onTtsSegment; this.onTopicSuggestions = options.onTopicSuggestions; this.onError = options.onError; + this.uiStateListener = options.onStateChange; this.unsubEvent = this.client.onEvent(this.handleEvent); - if (options.onStateChange) { - this.unsubState = this.client.onStateChange(options.onStateChange); + this.unsubState = this.client.onStateChange((state) => { + this.uiStateListener?.(state); + }); + } + + /** 列表预热接棒或刷新 UI 订阅时替换回调,不重建 WebSocket */ + attachUiCallbacks(options: { + onStreamingText?: StreamingTextCallback; + onTtsSegment?: (payload: TtsSegmentPayload) => void; + onError?: ErrorCallback; + onStateChange?: WsStateListener; + }): void { + if (this.destroyed) return; + if (options.onStreamingText !== undefined) { + this.onStreamingText = options.onStreamingText; + } + if (options.onTtsSegment !== undefined) { + this.onTtsSegment = options.onTtsSegment; + } + if (options.onError !== undefined) { + this.onError = options.onError; + } + if (options.onStateChange !== undefined) { + this.uiStateListener = options.onStateChange; + options.onStateChange(this.client.getState()); + } + if (!this.assistantTurnTtsSync && this.streamingBuffer.trim().length > 0) { + this.onStreamingText?.(this.streamingBuffer, false); } } @@ -97,15 +146,20 @@ export class RealtimeSession { } dispose(): void { + if (this.destroyed) return; + this.destroyed = true; this.flushStreamingBufferIfPending(); + this.resetAssistantTtsSyncState(); this.unsubEvent?.(); this.unsubState?.(); this.client.dispose(); } /** Returns true if the message was sent over the socket. 
*/ - sendText(text: string): boolean { - return this.client.sendText(text); + sendText(text: string, options?: { ttsThisTurn?: boolean }): boolean { + const tts = !!options?.ttsThisTurn; + this.assistantTurnTtsSync = tts; + return this.client.sendText(text, { ttsThisTurn: tts }); } sendAudioSegment( @@ -116,8 +170,11 @@ export class RealtimeSession { clientSegmentId?: string; isLast?: boolean; duration?: number; + ttsThisTurn?: boolean; }, ): boolean { + const tts = !!options?.ttsThisTurn; + this.assistantTurnTtsSync = tts; return this.client.send({ type: 'audio_segment', data: { @@ -127,6 +184,7 @@ export class RealtimeSession { client_segment_id: options?.clientSegmentId, is_last: options?.isLast, duration: options?.duration, + ...(options?.ttsThisTurn === true ? { tts_this_turn: true } : {}), }, }); } @@ -137,6 +195,7 @@ export class RealtimeSession { /** 通知服务端停止当前轮次后续 TTS 合成与下发(与客户端 stop 队列配合) */ sendTtsCancel(): boolean { + this.resetAssistantTtsSyncState(); return this.client.sendTtsCancel(); } @@ -144,8 +203,40 @@ export class RealtimeSession { return this.client.getState(); } + requestAssistantSegmentTts(body: { + assistantMessageId: string; + segmentIndex: number; + segmentText?: string; + }): boolean { + return this.client.sendTtsRequest(body); + } + // ─── Internal ─── + private resetAssistantTtsSyncState(): void { + this.assistantTurnTtsSync = false; + this.pendingTtsByKey.clear(); + } + + private flushBufferedTtsIfSync( + assistantMessageId: string | undefined, + index: number, + ): void { + if (!this.assistantTurnTtsSync) return; + const key = RealtimeSession.bufferedTtsKey(assistantMessageId, index); + const payload = this.pendingTtsByKey.get(key); + if (payload) { + this.pendingTtsByKey.delete(key); + this.onTtsSegment?.(payload); + } + } + + private finishAssistantTurnIfLastSegment(index: number, total: number): void { + if (index >= total - 1) { + this.resetAssistantTtsSyncState(); + } + } + private handleEvent: WsEventListener = (event: WsEvent) 
=> { if (event.kind === 'agent_response') { this.handleAgentChunk(event); @@ -155,14 +246,26 @@ export class RealtimeSession { if (event.kind === 'tts_audio_received') { const b64 = event.audioBase64?.trim(); const url = event.audioUrl?.trim(); - if (b64 || url) { - this.onTtsSegment?.({ - audioBase64: b64 || undefined, - audioUrl: url || undefined, - index: event.index, - total: event.total, - assistantMessageId: event.assistantMessageId, - }); + if (!b64 && !url) { + return; + } + const payload: TtsSegmentPayload = { + audioBase64: b64 || undefined, + audioUrl: url || undefined, + index: event.index, + total: event.total, + assistantMessageId: event.assistantMessageId, + manual: event.manual, + }; + if (this.assistantTurnTtsSync && !payload.manual) { + const idx = event.index ?? 0; + const key = RealtimeSession.bufferedTtsKey( + event.assistantMessageId, + idx, + ); + this.pendingTtsByKey.set(key, payload); + } else { + this.onTtsSegment?.(payload); } return; } @@ -179,6 +282,7 @@ export class RealtimeSession { handleWsEvent(this.queryClient, event); if (event.kind === 'session_error') { + this.resetAssistantTtsSyncState(); this.onError?.(event.message, event.code); } }; @@ -202,14 +306,19 @@ export class RealtimeSession { const total = event.total ?? 1; const index = event.index ?? 0; + const sync = this.assistantTurnTtsSync; if (total > 1) { const id = event.assistantMessageId != null ? 
assistantSegmentMessageId(event.assistantMessageId, index) : `${this.conversationId}_agent_${Date.now()}_${index}`; + if (sync) { + this.flushBufferedTtsIfSync(event.assistantMessageId, index); + } this.commitOneAssistantMessage(event.text, id); this.onStreamingText?.(event.text, true); + this.finishAssistantTurnIfLastSegment(index, total); return; } @@ -218,18 +327,30 @@ export class RealtimeSession { } this.streamingBuffer += event.text; - - // 与 coerced index/total 对齐:若服务端只带 text、省略 index/total,旧逻辑会 isComplete=false,永远不落库 const isComplete = index >= total - 1; - this.onStreamingText?.(this.streamingBuffer, isComplete); + if (!sync) { + this.onStreamingText?.(this.streamingBuffer, isComplete); + } if (isComplete) { + const assistantId = + event.assistantMessageId ?? this.pendingAssistantMessageId; const id = this.pendingAssistantMessageId ?? `${this.conversationId}_agent_${Date.now()}`; - this.commitStreamingBufferWithId(id); + if (sync) { + this.flushBufferedTtsIfSync(assistantId ?? undefined, 0); + this.commitStreamingBufferWithId(id); + const visible = + this.streamingBuffer.trim().length > 0 ? this.streamingBuffer : '…'; + this.onStreamingText?.(visible, true); + } else { + this.commitStreamingBufferWithId(id); + } + this.streamingBuffer = ''; this.pendingAssistantMessageId = null; + this.finishAssistantTurnIfLastSegment(0, 1); } } @@ -243,6 +364,9 @@ export class RealtimeSession { senderType: 'assistant', timestamp: Date.now(), messageType: 'text', + ...(looksLikeUuidAssistantMessageId(id) + ? { durableMessageId: id } + : {}), }; return [...(old ?? []), message]; }); @@ -256,7 +380,6 @@ export class RealtimeSession { } const fullText = this.streamingBuffer; - this.streamingBuffer = ''; const content = fullText.trim().length > 0 ? 
fullText : '…'; const messagesKey = conversationKeys.messages(this.conversationId); @@ -268,6 +391,9 @@ export class RealtimeSession { senderType: 'assistant', timestamp: Date.now(), messageType: 'text', + ...(looksLikeUuidAssistantMessageId(messageId) + ? { durableMessageId: messageId } + : {}), }; return [...(old ?? []), message]; }); @@ -299,6 +425,7 @@ export class RealtimeSession { this.pendingAssistantMessageId ?? `${this.conversationId}_agent_${Date.now()}`; this.commitStreamingBufferWithId(id); + this.streamingBuffer = ''; this.pendingAssistantMessageId = null; } } diff --git a/app-expo/src/features/conversation/types.ts b/app-expo/src/features/conversation/types.ts index dcd3930..377e247 100644 --- a/app-expo/src/features/conversation/types.ts +++ b/app-expo/src/features/conversation/types.ts @@ -72,6 +72,8 @@ export interface MessageItem { audioUri?: string; /** 助手 TTS 已上传的 COS URL 列表(与后端 `ttsAudioUrls` 一致),用于不重合成重复朗读 */ ttsAudioUrls?: string[]; + /** 落库后的助手消息 id(REST 历史同步),用于按需 TTS 请求 */ + durableMessageId?: string; } export interface OrganizeResponse { diff --git a/app-expo/src/features/profile/hooks.ts b/app-expo/src/features/profile/hooks.ts index d41fbd9..0954b75 100644 --- a/app-expo/src/features/profile/hooks.ts +++ b/app-expo/src/features/profile/hooks.ts @@ -2,6 +2,7 @@ import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { router } from 'expo-router'; import { tokenManager } from '@/core/auth/token-manager'; +import { disposeAllBackgroundConversationWs } from '@/features/conversation/conversation-ws-background-pool'; import { authKeys } from '@/features/auth/hooks'; import { profileApi } from './api'; @@ -97,6 +98,7 @@ export function usePurgeUserData() { return useMutation({ mutationFn: (body: PurgeUserDataRequest) => profileApi.purgeUserData(body), onSuccess: async () => { + disposeAllBackgroundConversationWs(); await tokenManager.clearTokens(); queryClient.clear(); queryClient.setQueryData(authKeys.tokenCheck, 
false); diff --git a/app-expo/src/features/voice/hooks/use-player.ts b/app-expo/src/features/voice/hooks/use-player.ts index d418ddc..8bb257f 100644 --- a/app-expo/src/features/voice/hooks/use-player.ts +++ b/app-expo/src/features/voice/hooks/use-player.ts @@ -15,6 +15,10 @@ interface UsePlayerResult { enqueue: (item: PlaybackItem) => void; /** Replace queue and play this item (e.g. user voice bubble vs other sources). */ enqueueExclusive: (item: PlaybackItem) => Promise; + /** Pause native playback without draining queue(与 stop 清空队列不同)。 */ + pausePlayback: () => void; + /** Continue after pausePlayback(需 status === 'paused') */ + resumePlayback: () => void; stop: () => void; } @@ -68,9 +72,11 @@ export function usePlayer(): UsePlayerResult { useEffect(() => { if (!currentSource || !player) return; if (!playerStatus.isLoaded) return; + /** 先于 isLoaded「抢暂停」时需保留暂停,避免本条自动 play 覆盖 pause */ + if (status === 'paused') return; player.play(); isPlayingRef.current = true; - }, [currentSource, player, playerStatus.isLoaded]); + }, [currentSource, player, playerStatus.isLoaded, status]); const playNext = useCallback(async () => { if (isPlayNextInProgressRef.current) return; @@ -114,6 +120,7 @@ export function usePlayer(): UsePlayerResult { // Detect playback completion → advance queue(必须曾 playing,避免换源瞬间沿用上一条的 duration/currentTime) useEffect(() => { + if (status === 'paused') return; if (!currentSource || !isPlayingRef.current) return; const { playing, currentTime, duration } = playerStatus; @@ -128,7 +135,32 @@ export function usePlayer(): UsePlayerResult { isPlayingRef.current = false; playNext(); } - }, [playerStatus, currentSource, playNext]); + }, [playerStatus, currentSource, playNext, status]); + + const pausePlayback = useCallback(() => { + setStatus((s) => { + if (s !== 'playing') return s; + if (player) { + player.pause(); + } + isPlayingRef.current = false; + return 'paused'; + }); + }, [player]); + + const resumePlayback = useCallback(async () => { + if (status !== 
'paused') return; + const acquired = await audioFocus.acquireForPlayback(); + if (!acquired) { + setStatus('idle'); + return; + } + if (!player) return; + if (!playerStatus.isLoaded) return; + player.play(); + setStatus('playing'); + isPlayingRef.current = true; + }, [status, player, playerStatus.isLoaded]); // Subscribe to audioFocus owner changes for recorder → idle recovery useEffect(() => { @@ -205,6 +237,8 @@ export function usePlayer(): UsePlayerResult { currentPlaybackItem, enqueue, enqueueExclusive, + pausePlayback, + resumePlayback, stop, }; } diff --git a/app-expo/src/i18n/generated/resources.ts b/app-expo/src/i18n/generated/resources.ts index 94ed1a6..c499901 100644 --- a/app-expo/src/i18n/generated/resources.ts +++ b/app-expo/src/i18n/generated/resources.ts @@ -86,6 +86,11 @@ interface Resources { inputPlaceholderVoice: 'Type here or hold the mic to speak...'; me: 'Me'; readAloudAgain: 'Play again'; + readAloudPause: 'Pause reading'; + readAloudResume: 'Resume reading'; + readAloudRequest: 'Read aloud'; + readAloudRequestFailed: 'Could not start playback. Check your connection.'; + readAloudNoMessageId: 'This message is not ready for on-demand reading yet. 
Pull to refresh or try again.'; readingAloud: 'Reading aloud…'; recentChats: 'Recent Chats'; recordingPermissionDenied: 'Microphone permission is required to record'; @@ -203,6 +208,29 @@ interface Resources { feedbackPageTitle: 'Share your thoughts'; title: 'Help & Support'; }; + personalInfo: { + avatarPresetFailed: 'Could not set preset avatar'; + avatarUploadFailed: 'Could not upload avatar'; + birthPlacePlaceholder: 'Birthplace'; + birthYearPlaceholder: 'Birth year'; + cancel: 'Cancel'; + changeAvatar: 'Change photo'; + chooseFromLibrary: 'Choose from library'; + choosePreset: 'Preset avatars'; + grewUpPlaceholder: 'Where you grew up'; + libraryPermissionDenied: 'Photo library access is required to pick an image'; + nickname: 'Nickname'; + nicknamePlaceholder: 'Enter nickname'; + nicknameRequired: 'Please enter a nickname'; + occupationPlaceholder: 'Occupation'; + presetPickTitle: 'Choose a preset'; + save: 'Save'; + saveFailed: 'Could not save'; + savePartialBody: 'Your nickname was saved, but profile fields below could not be saved. Check your connection and tap Save again.'; + savePartialTitle: 'Partially saved'; + saving: 'Saving…'; + title: 'Personal info'; + }; signOut: 'Sign Out'; signingOut: 'Signing out...'; userNamePlaceholder: 'User'; diff --git a/app-expo/src/i18n/locales/en/conversation.json b/app-expo/src/i18n/locales/en/conversation.json index a54eeac..7586f10 100644 --- a/app-expo/src/i18n/locales/en/conversation.json +++ b/app-expo/src/i18n/locales/en/conversation.json @@ -27,10 +27,17 @@ "recentChats": "Recent Chats", "stopReadingAloud": "Stop reading aloud", "readAloudAgain": "Play again", + "readAloudPause": "Pause reading", + "readAloudResume": "Resume reading", + "readAloudRequest": "Read aloud", + "readAloudRequestFailed": "Could not start playback. Check your connection.", + "readAloudNoMessageId": "This message is not ready for on-demand reading yet. 
Pull to refresh or try again.", "cannotReadAloud": "Read unavailable", "readingAloud": "Reading aloud…", "recordingPermissionDenied": "Microphone permission is required to record", "recordingStartFailed": "Unable to start recording. Please try again.", + "ttsThisTurn": "Speak", + "ttsThisTurnAccessibility": "When on, assistant replies synthesize speech before text appears.", "send": "Send", "startNewSubtitle": "Capture a new memory or share your thoughts with your companion.", "switchToText": "Switch to text input", diff --git a/app-expo/src/i18n/locales/en/profile.json b/app-expo/src/i18n/locales/en/profile.json index 6bfc981..c9d4cdf 100644 --- a/app-expo/src/i18n/locales/en/profile.json +++ b/app-expo/src/i18n/locales/en/profile.json @@ -33,6 +33,29 @@ "title": "Data & Privacy" }, "editAvatar": "Edit Profile Picture", + "personalInfo": { + "avatarPresetFailed": "Could not set preset avatar", + "avatarUploadFailed": "Could not upload avatar", + "cancel": "Cancel", + "birthPlacePlaceholder": "Birthplace", + "birthYearPlaceholder": "Birth year", + "changeAvatar": "Change photo", + "chooseFromLibrary": "Choose from library", + "choosePreset": "Preset avatars", + "grewUpPlaceholder": "Where you grew up", + "libraryPermissionDenied": "Photo library access is required to pick an image", + "nickname": "Nickname", + "nicknamePlaceholder": "Enter nickname", + "nicknameRequired": "Please enter a nickname", + "occupationPlaceholder": "Occupation", + "presetPickTitle": "Choose a preset", + "save": "Save", + "saveFailed": "Could not save", + "savePartialBody": "Your nickname was saved, but profile fields below could not be saved. 
Check your connection and tap Save again.", + "savePartialTitle": "Partially saved", + "saving": "Saving…", + "title": "Personal info" + }, "helpSupport": { "faq": "FAQ", "feedback": "Feedback & Support", diff --git a/app-expo/src/i18n/locales/zh/conversation.json b/app-expo/src/i18n/locales/zh/conversation.json index 2621fbe..93720fc 100644 --- a/app-expo/src/i18n/locales/zh/conversation.json +++ b/app-expo/src/i18n/locales/zh/conversation.json @@ -27,10 +27,17 @@ "recentChats": "最近对话", "stopReadingAloud": "停止朗读", "readAloudAgain": "再读", + "readAloudPause": "暂停朗读", + "readAloudResume": "继续朗读", + "readAloudRequest": "朗读", + "readAloudRequestFailed": "无法开始朗读,请检查网络或稍后重试。", + "readAloudNoMessageId": "该条尚未可朗读,请下拉刷新或稍后再试。", "cannotReadAloud": "暂无法朗读", "readingAloud": "朗读中…", "recordingPermissionDenied": "需要麦克风权限才能录音", "recordingStartFailed": "录音初始化失败,请重试", + "ttsThisTurn": "朗读", + "ttsThisTurnAccessibility": "开启后本条消息的助手回复将先合成语音再显示文字", "send": "发送", "startNewSubtitle": "记录新回忆,或与岁月知己分享你的想法。", "switchToText": "切换到文字输入", diff --git a/app-expo/src/i18n/locales/zh/profile.json b/app-expo/src/i18n/locales/zh/profile.json index 9ed5b81..e993e8c 100644 --- a/app-expo/src/i18n/locales/zh/profile.json +++ b/app-expo/src/i18n/locales/zh/profile.json @@ -33,6 +33,29 @@ "title": "数据与隐私" }, "editAvatar": "编辑头像", + "personalInfo": { + "avatarPresetFailed": "设置预设头像失败", + "avatarUploadFailed": "上传头像失败", + "cancel": "取消", + "birthPlacePlaceholder": "出生地", + "birthYearPlaceholder": "出生年份", + "changeAvatar": "更换头像", + "chooseFromLibrary": "从相册选择", + "choosePreset": "预设头像", + "grewUpPlaceholder": "成长地", + "libraryPermissionDenied": "需要相册权限才能选择图片", + "nickname": "昵称", + "nicknamePlaceholder": "请输入昵称", + "nicknameRequired": "请填写昵称", + "occupationPlaceholder": "职业", + "presetPickTitle": "选择预设", + "save": "保存", + "saveFailed": "保存失败", + "savePartialBody": "昵称已更新,但下面的档案字段未能保存。请检查网络后再次点击保存。", + "savePartialTitle": "部分保存成功", + "saving": "保存中…", + "title": "个人信息" + }, "helpSupport": { "faq": "常见问题", 
"feedback": "反馈与客服", diff --git a/app-expo/tests/core/ws/client.test.ts b/app-expo/tests/core/ws/client.test.ts index ad46e39..220d2dc 100644 --- a/app-expo/tests/core/ws/client.test.ts +++ b/app-expo/tests/core/ws/client.test.ts @@ -9,7 +9,7 @@ jest.mock('@/core/auth/token-manager', () => ({ jest.mock('@/core/config', () => ({ config: { - wsBaseUrl: 'ws://localhost:8000', + wsBaseUrl: 'ws://localhost:8000/', ws: { reconnectMaxRetries: 3, reconnectBaseDelayMs: 10, @@ -23,6 +23,7 @@ jest.mock('@/core/config', () => ({ class MockWebSocket { static OPEN = 1; static CLOSED = 3; + static instances: MockWebSocket[] = []; readyState = MockWebSocket.OPEN; onopen: (() => void) | null = null; @@ -32,6 +33,7 @@ class MockWebSocket { sentMessages: string[] = []; constructor(public url: string) { + MockWebSocket.instances.push(this); setTimeout(() => this.onopen?.(), 0); } @@ -53,6 +55,7 @@ class MockWebSocket { describe('WsClient', () => { afterEach(() => { jest.clearAllMocks(); + MockWebSocket.instances = []; }); test('connects with token and conversation id in URL', async () => { @@ -66,6 +69,9 @@ describe('WsClient', () => { expect(states).toContain('connecting'); expect(states).toContain('connected'); + expect(MockWebSocket.instances[0]?.url).toBe( + 'ws://localhost:8000/ws/conversation/conv-123?token=test-token', + ); client.dispose(); }); @@ -133,6 +139,27 @@ describe('WsClient', () => { client.dispose(); }); + test('sends text with tts_this_turn when requested', async () => { + const client = new WsClient('conv-123'); + + await client.connect(); + await new Promise((r) => setTimeout(r, 10)); + + client.sendText('Hello', { ttsThisTurn: true }); + + const ws = (client as unknown as { ws: MockWebSocket }).ws; + expect(ws.sentMessages).toHaveLength(1); + + const sent = JSON.parse(ws.sentMessages[0]); + expect(sent).toEqual({ + type: 'text', + conversation_id: 'conv-123', + data: { text: 'Hello', tts_this_turn: true }, + }); + + client.dispose(); + }); + test('ignores 
unknown message types without crashing', async () => { const client = new WsClient('conv-123'); const events: WsEvent[] = []; diff --git a/app-expo/tests/features/conversation/entry-warmup.test.ts b/app-expo/tests/features/conversation/entry-warmup.test.ts new file mode 100644 index 0000000..ce9d2f4 --- /dev/null +++ b/app-expo/tests/features/conversation/entry-warmup.test.ts @@ -0,0 +1,126 @@ +import { QueryClient } from '@tanstack/react-query'; + +import { + prefetchConversationMessages, + warmupConversationOpening, +} from '@/features/conversation/entry-warmup'; +import { conversationKeys } from '@/features/conversation/query-keys'; +import type { MessageItem } from '@/features/conversation/types'; + +const mockLoadMessages = jest.fn(); +const mockRegisterPreparedRealtimeSession = jest.fn(); +let mockConnectImpl: + | ((options: { + conversationId: string; + queryClient: QueryClient; + }) => Promise | void) + | null = null; +const mockSessions: Array<{ + attachUiCallbacks: jest.Mock; + connect: jest.Mock; + dispose: jest.Mock; +}> = []; + +jest.mock('@/features/conversation/conversation-messages-repository', () => ({ + conversationMessagesRepository: { + loadMessages: (conversationId: string) => mockLoadMessages(conversationId), + }, +})); + +jest.mock('@/features/conversation/prepared-session-registry', () => ({ + registerPreparedRealtimeSession: (conversationId: string, session: unknown) => + mockRegisterPreparedRealtimeSession(conversationId, session), +})); + +jest.mock('@/features/conversation/realtime-session', () => ({ + RealtimeSession: jest.fn().mockImplementation((options) => { + const session = { + attachUiCallbacks: jest.fn(), + connect: jest.fn(async () => { + await mockConnectImpl?.(options); + }), + dispose: jest.fn(), + }; + mockSessions.push(session); + return session; + }), +})); + +function createQueryClient(): QueryClient { + return new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: Infinity }, + mutations: { retry: false 
}, + }, + }); +} + +function assistantMessage(id = 'assistant-1'): MessageItem { + return { + id, + conversationId: 'conv-1', + content: '你好,今天想聊哪段回忆?', + senderType: 'assistant', + timestamp: 1, + messageType: 'text', + }; +} + +describe('conversation entry warmup', () => { + let queryClient: QueryClient; + + beforeEach(() => { + queryClient = createQueryClient(); + mockLoadMessages.mockReset(); + mockRegisterPreparedRealtimeSession.mockReset(); + mockConnectImpl = null; + mockSessions.length = 0; + }); + + afterEach(async () => { + await queryClient.cancelQueries(); + queryClient.clear(); + }); + + test('prefetches messages without throwing on load failure', async () => { + mockLoadMessages.mockRejectedValueOnce(new Error('network down')); + + await expect( + prefetchConversationMessages(queryClient, 'conv-1'), + ).resolves.toBeUndefined(); + }); + + test('uses refreshed history and skips websocket when opening is already cached', async () => { + const existing = assistantMessage(); + mockLoadMessages.mockResolvedValueOnce([existing]); + + await warmupConversationOpening(queryClient, 'conv-1'); + + expect(mockLoadMessages).toHaveBeenCalledWith('conv-1'); + expect(mockSessions).toHaveLength(0); + expect( + queryClient.getQueryData(conversationKeys.messages('conv-1')), + ).toEqual([existing]); + }); + + test('connects websocket and registers prepared session after opening arrives', async () => { + const opened = assistantMessage(); + mockLoadMessages.mockResolvedValueOnce([]).mockResolvedValueOnce([opened]); + mockConnectImpl = ({ conversationId, queryClient }) => { + queryClient.setQueryData(conversationKeys.messages(conversationId), [ + opened, + ]); + }; + + await warmupConversationOpening(queryClient, 'conv-1'); + + expect(mockSessions).toHaveLength(1); + expect(mockSessions[0]?.attachUiCallbacks).toHaveBeenCalled(); + expect(mockSessions[0]?.connect).toHaveBeenCalled(); + expect(mockSessions[0]?.dispose).not.toHaveBeenCalled(); + 
expect(mockRegisterPreparedRealtimeSession).toHaveBeenCalledWith( + 'conv-1', + mockSessions[0], + ); + }); +}); diff --git a/app-expo/tests/features/voice/use-player.test.tsx b/app-expo/tests/features/voice/use-player.test.tsx index 3122b94..15f90c1 100644 --- a/app-expo/tests/features/voice/use-player.test.tsx +++ b/app-expo/tests/features/voice/use-player.test.tsx @@ -1,5 +1,6 @@ -import { renderHook } from '@testing-library/react-native'; +import { act, renderHook } from '@testing-library/react-native'; +import { audioFocus } from '@/core/audio/audio-focus'; import { usePlayer } from '@/features/voice/hooks/use-player'; const mockUseAudioPlayer = jest.fn(); @@ -34,6 +35,8 @@ describe('usePlayer', () => { currentTime: 0, duration: 0, }); + jest.mocked(audioFocus.acquireForPlayback).mockResolvedValue(true); + jest.mocked(audioFocus.releaseIfOwnedBy).mockResolvedValue(undefined); }); test('keeps the native audio session active while app-level audio focus owns teardown', () => { @@ -47,4 +50,81 @@ describe('usePlayer', () => { }), ); }); + + test('pausePlayback toggles playing→paused and invokes native pause', async () => { + mockUseAudioPlayerStatus.mockReturnValue({ + isLoaded: true, + playing: true, + currentTime: 0.1, + duration: 10, + }); + const pause = jest.fn(); + const play = jest.fn(); + mockUseAudioPlayer.mockReturnValue({ pause, play }); + + const { result } = renderHook(() => usePlayer()); + + await act(async () => { + await result.current.enqueueExclusive({ + uri: 'file:///fixture.mp3', + kind: 'voice', + }); + }); + + expect(result.current.status).toBe('playing'); + + act(() => { + result.current.pausePlayback(); + }); + + expect(pause).toHaveBeenCalled(); + expect(result.current.status).toBe('paused'); + }); + + test('resumePlayback toggles paused→playing and invokes native play', async () => { + mockUseAudioPlayerStatus.mockReturnValue({ + isLoaded: true, + playing: false, + currentTime: 0.1, + duration: 10, + }); + const pause = jest.fn(); + const 
play = jest.fn(); + mockUseAudioPlayer.mockReturnValue({ pause, play }); + + const { result } = renderHook(() => usePlayer()); + + await act(async () => { + await result.current.enqueueExclusive({ + uri: 'file:///fixture.mp3', + kind: 'voice', + }); + }); + + act(() => { + result.current.pausePlayback(); + }); + expect(result.current.status).toBe('paused'); + + await act(async () => { + await result.current.resumePlayback(); + }); + + expect(play).toHaveBeenCalled(); + expect(result.current.status).toBe('playing'); + }); + + test('pausePlayback is a no-op while idle', async () => { + const pause = jest.fn(); + mockUseAudioPlayer.mockReturnValue({ pause, play: jest.fn() }); + + const { result } = renderHook(() => usePlayer()); + + act(() => { + result.current.pausePlayback(); + }); + + expect(pause).not.toHaveBeenCalled(); + expect(result.current.status).toBe('idle'); + }); });