feat(api): 接入智谱 embedding-3(1024 维)并迁移 memory_chunks 向量列
This commit is contained in:
@@ -33,6 +33,16 @@ DEEPSEEK_API_KEY=your_deepseek_api_key
|
||||
DEEPSEEK_BASE_URL=https://api.deepseek.com
|
||||
DEEPSEEK_MODEL=deepseek-chat
|
||||
|
||||
# =============================================================================
|
||||
# Memory 向量(智谱 BigModel 国内 embedding-3;与 DeepSeek/OpenAI 用途分离)
|
||||
# 文档:https://docs.bigmodel.cn/cn/guide/models/embedding/embedding-3
|
||||
# 本期固定 1024 维;库表经迁移与 MEMORY_EMBEDDING_DIMENSION 一致。
|
||||
# =============================================================================
|
||||
ZHIPU_API_KEY=your_zhipu_api_key
|
||||
# 默认国内通用端点(与 ZhipuAiClient 一致)
|
||||
# EMBEDDING_BASE_URL=https://open.bigmodel.cn/api/paas/v4
|
||||
EMBEDDING_MODEL=embedding-3
|
||||
|
||||
# Chat 访谈:每轮根据用户内容判定主人生阶段(关则仅用关键词,省一次 LLM)
|
||||
# CHAT_STAGE_DETECTION_ENABLED=true
|
||||
# CHAT_STAGE_DETECTION_MAX_TOKENS=128
|
||||
|
||||
@@ -17,6 +17,16 @@ DEEPSEEK_API_KEY=sk-09f17fb61c5a4299a3afc2a01de7af75
|
||||
DEEPSEEK_BASE_URL=https://api.deepseek.com
|
||||
DEEPSEEK_MODEL=deepseek-chat
|
||||
|
||||
# =============================================================================
|
||||
# Memory 向量(智谱 BigModel 国内 embedding-3;与 DeepSeek/OpenAI 用途分离)
|
||||
# 文档:https://docs.bigmodel.cn/cn/guide/models/embedding/embedding-3
|
||||
# 本期固定 1024 维;库表经迁移与 MEMORY_EMBEDDING_DIMENSION 一致。
|
||||
# =============================================================================
|
||||
ZHIPU_API_KEY=524eda18eb3848e881eefe4c7ef17ec2.xBmGUabYDEa44m3M
|
||||
# 默认国内通用端点(与 ZhipuAiClient 一致)
|
||||
# EMBEDDING_BASE_URL=https://open.bigmodel.cn/api/paas/v4
|
||||
EMBEDDING_MODEL=embedding-3
|
||||
|
||||
# =============================================================================
|
||||
# Database
|
||||
# =============================================================================
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""memory_chunks.embedding: vector(1536) -> vector(1024),并清空旧向量。
|
||||
|
||||
与智谱 embedding-3(固定 1024 维)及 ORM 一致;旧 OpenAI 等模型向量不可与智谱混用,故先置 NULL 再改类型。
|
||||
|
||||
Revision ID: 0004_memory_embedding_1024
|
||||
Revises: 0003_timeline_memory_source
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "0004_memory_embedding_1024"
|
||||
down_revision: Union[str, None] = "0003_timeline_memory_source"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(sa.text("UPDATE memory_chunks SET embedding = NULL"))
|
||||
op.execute(
|
||||
sa.text("ALTER TABLE memory_chunks ALTER COLUMN embedding TYPE vector(1024)")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(sa.text("UPDATE memory_chunks SET embedding = NULL"))
|
||||
op.execute(
|
||||
sa.text("ALTER TABLE memory_chunks ALTER COLUMN embedding TYPE vector(1536)")
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
"""OpenAI embedding adapter — implements EmbeddingProvider port."""
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider:
|
||||
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
|
||||
self._client = AsyncOpenAI(api_key=api_key) if api_key else None
|
||||
self._model = model
|
||||
|
||||
async def embed_text(self, text: str) -> list[float]:
|
||||
if not self._client:
|
||||
return []
|
||||
resp = await self._client.embeddings.create(input=[text], model=self._model)
|
||||
return resp.data[0].embedding
|
||||
|
||||
async def embed_texts(self, texts: list[str]) -> list[list[float]]:
|
||||
if not self._client or not texts:
|
||||
return []
|
||||
resp = await self._client.embeddings.create(input=texts, model=self._model)
|
||||
return [item.embedding for item in sorted(resp.data, key=lambda d: d.index)]
|
||||
56
api/app/adapters/embedding/zhipu.py
Normal file
56
api/app/adapters/embedding/zhipu.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""智谱 BigModel 国内 embedding API — 实现 EmbeddingProvider(zai-sdk / ZhipuAiClient)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from zai import ZhipuAiClient
|
||||
|
||||
from app.core.embedding import MEMORY_EMBEDDING_DIMENSION
|
||||
|
||||
# 单次请求最多 64 条文本(智谱 Embedding-3 文档)
|
||||
_EMBED_BATCH_SIZE = 64
|
||||
|
||||
|
||||
class ZhipuEmbeddingProvider:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
model: str = "embedding-3",
|
||||
) -> None:
|
||||
self._model = model
|
||||
if not api_key:
|
||||
self._client = None
|
||||
elif base_url:
|
||||
self._client = ZhipuAiClient(
|
||||
api_key=api_key,
|
||||
base_url=base_url.rstrip("/"),
|
||||
)
|
||||
else:
|
||||
self._client = ZhipuAiClient(api_key=api_key)
|
||||
|
||||
def _create_vectors_sync(self, texts: list[str]) -> list[list[float]]:
|
||||
assert self._client is not None
|
||||
resp = self._client.embeddings.create(
|
||||
input=texts,
|
||||
model=self._model,
|
||||
dimensions=MEMORY_EMBEDDING_DIMENSION,
|
||||
)
|
||||
ordered = sorted(resp.data, key=lambda d: d.index or 0)
|
||||
return [list(item.embedding) for item in ordered]
|
||||
|
||||
async def embed_text(self, text: str) -> list[float]:
|
||||
vectors = await self.embed_texts([text])
|
||||
return vectors[0] if vectors else []
|
||||
|
||||
async def embed_texts(self, texts: list[str]) -> list[list[float]]:
|
||||
if not self._client or not texts:
|
||||
return []
|
||||
out: list[list[float]] = []
|
||||
for i in range(0, len(texts), _EMBED_BATCH_SIZE):
|
||||
batch = texts[i : i + _EMBED_BATCH_SIZE]
|
||||
part = await asyncio.to_thread(self._create_vectors_sync, batch)
|
||||
out.extend(part)
|
||||
return out
|
||||
@@ -49,6 +49,11 @@ class Settings(BaseSettings):
|
||||
llm_model: str = ""
|
||||
llm_temperature: float = 0.7
|
||||
|
||||
# ── Memory 向量(智谱 BigModel 国内 embedding-3;与 LLM/DeepSeek 密钥分离)──
|
||||
zhipu_api_key: str = ""
|
||||
embedding_base_url: str = "https://open.bigmodel.cn/api/paas/v4"
|
||||
embedding_model: str = "embedding-3"
|
||||
|
||||
# ── Chat 访谈(短回复:token 上限 + 代码截断,见 reply_limits)──
|
||||
chat_interview_max_tokens: int = 320
|
||||
chat_interview_max_segments: int = 2
|
||||
|
||||
@@ -124,10 +124,13 @@ def get_object_storage() -> ObjectStorage:
|
||||
|
||||
@lru_cache
|
||||
def get_embedding_provider() -> EmbeddingProvider:
|
||||
from app.adapters.embedding.openai import OpenAIEmbeddingProvider
|
||||
from app.adapters.embedding.zhipu import ZhipuEmbeddingProvider
|
||||
|
||||
api_key = settings.openai_api_key or settings.deepseek_api_key
|
||||
return OpenAIEmbeddingProvider(api_key=api_key)
|
||||
return ZhipuEmbeddingProvider(
|
||||
api_key=settings.zhipu_api_key,
|
||||
base_url=settings.embedding_base_url or None,
|
||||
model=settings.embedding_model,
|
||||
)
|
||||
|
||||
|
||||
# ── Auth dependencies ────────────────────────────────────────
|
||||
|
||||
6
api/app/core/embedding.py
Normal file
6
api/app/core/embedding.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Memory chunk 向量维度(与智谱 embedding-3、pgvector 列一致)。
|
||||
|
||||
本期固定 1024;若调整维度需独立迁移与排期,勿仅改此处常量。
|
||||
"""
|
||||
|
||||
MEMORY_EMBEDDING_DIMENSION = 1024
|
||||
@@ -10,12 +10,13 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR as TSVector
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.db import Base, utc_now
|
||||
from app.core.embedding import MEMORY_EMBEDDING_DIMENSION
|
||||
|
||||
pgvector_type = Vector(1536)
|
||||
pgvector_type = Vector(MEMORY_EMBEDDING_DIMENSION)
|
||||
|
||||
|
||||
class MemorySource(Base):
|
||||
|
||||
@@ -30,6 +30,7 @@ dependencies = [
|
||||
"tencentcloud-sdk-python>=3.1.54",
|
||||
"weasyprint>=68.1",
|
||||
"wechatpayv3>=2.0.2",
|
||||
"zai-sdk>=0.2.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
28
api/uv.lock
generated
28
api/uv.lock
generated
@@ -101,6 +101,7 @@ dependencies = [
|
||||
{ name = "tencentcloud-sdk-python" },
|
||||
{ name = "weasyprint" },
|
||||
{ name = "wechatpayv3" },
|
||||
{ name = "zai-sdk" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
@@ -140,6 +141,7 @@ requires-dist = [
|
||||
{ name = "tencentcloud-sdk-python", specifier = ">=3.1.54" },
|
||||
{ name = "weasyprint", specifier = ">=68.1" },
|
||||
{ name = "wechatpayv3", specifier = ">=2.0.2" },
|
||||
{ name = "zai-sdk", specifier = ">=0.2.2" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
@@ -312,6 +314,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/ef/e7e485ce5e4ba3843a0a92feb767c7b6098fd6e65ce752918074d175ae71/brotlicffi-1.2.0.1-cp38-abi3-win_amd64.whl", hash = "sha256:da2e82a08e7778b8bc539d27ca03cdd684113e81394bfaaad8d0dfc6a17ddede", size = 379026, upload-time = "2026-03-05T19:54:04.322Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "7.0.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/dd/57fe3fdb6e65b25a5987fd2cdc7e22db0aef508b91634d2e57d22928d41b/cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990", size = 37367, upload-time = "2026-03-09T20:51:29.451Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "celery"
|
||||
version = "5.6.2"
|
||||
@@ -2924,6 +2935,23 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/c9/7243eb3f9eaabd1a88a5a5acadf06df2d83b100c62684b7425c6a11bcaa8/xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95", size = 28898, upload-time = "2025-10-02T14:36:17.843Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zai-sdk"
|
||||
version = "0.2.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cachetools" },
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-core" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "sniffio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/64/a1/0f392da4afd53977383c03ad5c4cc4079b958fb4acd9f8bdeb8fa1e2e5e5/zai_sdk-0.2.2.tar.gz", hash = "sha256:841af188586c41f0f98abcad17b162dc97461c3ab8b3d09f93b0333ebdfe72c3", size = 76605, upload-time = "2026-02-02T16:10:28.502Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/02/d586e861e711f8449c567b01b004d7c5f65ee07c75709992a84e3ef9c40e/zai_sdk-0.2.2-py3-none-any.whl", hash = "sha256:675e9ff3fc2a86e38631331469b935589ae60492127f2c2822c7555214f2dd25", size = 125579, upload-time = "2026-02-02T16:10:26.975Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zopfli"
|
||||
version = "0.4.1"
|
||||
|
||||
Reference in New Issue
Block a user