"""智谱 BigModel 国内 embedding API — 实现 EmbeddingProvider(zai-sdk / ZhipuAiClient)。""" from __future__ import annotations import asyncio from zai import ZhipuAiClient from app.core.business_telemetry import business_span from app.core.embedding import MEMORY_EMBEDDING_DIMENSION from app.core.logging import get_logger _logger = get_logger(__name__) # 单次请求最多 64 条文本(智谱 Embedding-3 文档) _EMBED_BATCH_SIZE = 64 class ZhipuEmbeddingProvider: def __init__( self, *, api_key: str, base_url: str | None = None, model: str = "embedding-3", ) -> None: self._model = model if not api_key: _logger.warning( "ZhipuEmbeddingProvider: api_key 为空,embedding 将不可用(记忆检索与 ingest 向量写入会降级)" ) self._client = None elif base_url: self._client = ZhipuAiClient( api_key=api_key, base_url=base_url.rstrip("/"), ) else: self._client = ZhipuAiClient(api_key=api_key) def is_available(self) -> bool: return self._client is not None def _create_vectors_sync(self, texts: list[str]) -> list[list[float]]: assert self._client is not None resp = self._client.embeddings.create( input=texts, model=self._model, dimensions=MEMORY_EMBEDDING_DIMENSION, ) ordered = sorted(resp.data, key=lambda d: d.index or 0) return [list(item.embedding) for item in ordered] async def embed_text(self, text: str) -> list[float]: vectors = await self.embed_texts([text]) return vectors[0] if vectors else [] async def embed_texts(self, texts: list[str]) -> list[list[float]]: if not self._client or not texts: return [] with business_span("embedding.zhipu.embed", batch_size=len(texts)): out: list[list[float]] = [] for i in range(0, len(texts), _EMBED_BATCH_SIZE): batch = texts[i : i + _EMBED_BATCH_SIZE] part = await asyncio.to_thread(self._create_vectors_sync, batch) out.extend(part) return out def embed_text_sync(self, text: str) -> list[float]: vecs = self.embed_texts_sync([text]) return vecs[0] if vecs else [] def embed_texts_sync(self, texts: list[str]) -> list[list[float]]: if not self._client or not texts: return [] with business_span("embedding.zhipu.embed", batch_size=len(texts)): out: list[list[float]] = [] for i in range(0, len(texts), _EMBED_BATCH_SIZE): batch = texts[i : i + _EMBED_BATCH_SIZE] out.extend(self._create_vectors_sync(batch)) return out