2026-03-30 13:54:35 +08:00
|
|
|
|
"""智谱 BigModel 国内 embedding API — 实现 EmbeddingProvider(zai-sdk / ZhipuAiClient)。"""
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
from zai import ZhipuAiClient
|
|
|
|
|
|
|
|
|
|
|
|
from app.core.embedding import MEMORY_EMBEDDING_DIMENSION
|
2026-04-03 11:43:16 +08:00
|
|
|
|
from app.core.logging import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger; the provider uses it to warn when it is constructed
# without an API key (embedding then becomes unavailable).
_logger = get_logger(__name__)
|
2026-03-30 13:54:35 +08:00
|
|
|
|
|
|
|
|
|
|
# 单次请求最多 64 条文本(智谱 Embedding-3 文档)
|
|
|
|
|
|
_EMBED_BATCH_SIZE = 64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZhipuEmbeddingProvider:
    """Embedding provider backed by the Zhipu BigModel embedding API.

    Requests go through the zai-sdk ``ZhipuAiClient``. If ``api_key`` is
    empty, no client is created: ``is_available()`` returns ``False`` and
    every ``embed_*`` method returns ``[]`` without contacting the API.

    Input lists longer than the batch size are split into consecutive
    batches, sent one request at a time; the resulting vectors are
    concatenated in batch order.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str | None = None,
        model: str = "embedding-3",
        dimensions: int | None = None,
        batch_size: int | None = None,
    ) -> None:
        """Create the provider and, if an API key is given, its SDK client.

        Args:
            api_key: Zhipu API key. An empty value disables the provider and
                logs a warning instead of raising.
            base_url: Optional custom endpoint. Surrounding whitespace and
                trailing slashes are removed; a value that is blank after
                that cleanup falls back to the SDK's default endpoint.
            model: Embedding model name sent with every request.
            dimensions: Vector size requested from the API (its
                ``dimensions`` argument). Defaults to
                ``MEMORY_EMBEDDING_DIMENSION``; presumably vectors stored
                for memory retrieval must keep that default — confirm
                before overriding.
            batch_size: Maximum number of texts per API request. Defaults to
                ``_EMBED_BATCH_SIZE``.

        Raises:
            ValueError: If ``batch_size`` is given and is not positive.
        """
        # Validate first so a bad batch size fails at construction time
        # rather than as an opaque ``range()`` error on the first embed call.
        if batch_size is not None and batch_size <= 0:
            raise ValueError(
                f"ZhipuEmbeddingProvider: batch_size must be positive, got {batch_size}"
            )
        self._model = model
        self._dimensions = (
            MEMORY_EMBEDDING_DIMENSION if dimensions is None else dimensions
        )
        self._batch_size = _EMBED_BATCH_SIZE if batch_size is None else batch_size

        # Tolerate whitespace / trailing slashes coming from configuration;
        # e.g. "/" used to reach the client as an empty base_url.
        endpoint = base_url.strip().rstrip("/") if base_url else ""

        if not api_key:
            _logger.warning(
                "ZhipuEmbeddingProvider: api_key 为空,embedding 将不可用(记忆检索与 ingest 向量写入会降级)"
            )
            self._client = None
        elif endpoint:
            self._client = ZhipuAiClient(api_key=api_key, base_url=endpoint)
        else:
            self._client = ZhipuAiClient(api_key=api_key)

    def is_available(self) -> bool:
        """Return ``True`` when an SDK client was created (non-empty api_key)."""
        return self._client is not None

    def _batches(self, texts: list[str]) -> list[list[str]]:
        """Split ``texts`` into consecutive slices of at most ``self._batch_size`` items."""
        step = self._batch_size
        return [texts[start : start + step] for start in range(0, len(texts), step)]

    def _create_vectors_sync(self, texts: list[str]) -> list[list[float]]:
        """Embed one batch with a single blocking SDK call.

        Args:
            texts: One batch of input strings (at most the batch size).

        Returns:
            One vector (list of floats) per item in the API response.

        Raises:
            RuntimeError: If the provider has no client (empty api_key).
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if self._client is None:
            raise RuntimeError(
                "ZhipuEmbeddingProvider: no client configured (api_key is empty)"
            )
        resp = self._client.embeddings.create(
            input=texts,
            model=self._model,
            dimensions=self._dimensions,
        )
        # Order vectors by each item's ``index`` (a missing/None index sorts
        # as 0; the sort is stable). Presumably ``index`` is the position of
        # the corresponding input text — verify against the SDK docs.
        ordered = sorted(resp.data, key=lambda d: d.index or 0)
        return [list(item.embedding) for item in ordered]

    async def embed_text(self, text: str) -> list[float]:
        """Embed a single text; ``[]`` if unavailable or no vector came back."""
        vectors = await self.embed_texts([text])
        return vectors[0] if vectors else []

    async def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` without blocking the event loop.

        Each batch's blocking SDK call runs in a worker thread via
        ``asyncio.to_thread``; batches are awaited one after another.

        Returns:
            The concatenated vectors, or ``[]`` when the provider is
            unavailable or ``texts`` is empty.
        """
        if not self.is_available() or not texts:
            return []
        vectors: list[list[float]] = []
        for batch in self._batches(texts):
            vectors.extend(await asyncio.to_thread(self._create_vectors_sync, batch))
        return vectors

    def embed_text_sync(self, text: str) -> list[float]:
        """Blocking variant of ``embed_text``."""
        vectors = self.embed_texts_sync([text])
        return vectors[0] if vectors else []

    def embed_texts_sync(self, texts: list[str]) -> list[list[float]]:
        """Blocking variant of ``embed_texts``; calls the SDK in the current thread.

        Returns:
            The concatenated vectors, or ``[]`` when the provider is
            unavailable or ``texts`` is empty.
        """
        if not self.is_available() or not texts:
            return []
        vectors: list[list[float]] = []
        for batch in self._batches(texts):
            vectors.extend(self._create_vectors_sync(batch))
        return vectors
|