Files
life-echo/api/app/adapters/embedding/zhipu.py

57 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Zhipu BigModel (mainland China) embedding API — implements EmbeddingProvider via zai-sdk / ZhipuAiClient."""
from __future__ import annotations
import asyncio
from zai import ZhipuAiClient
from app.core.embedding import MEMORY_EMBEDDING_DIMENSION
# At most 64 texts per request (per the Zhipu Embedding-3 documentation).
_EMBED_BATCH_SIZE = 64
class ZhipuEmbeddingProvider:
    """Embedding provider backed by the Zhipu BigModel embeddings API.

    Wraps the synchronous ``ZhipuAiClient`` and exposes async methods by
    running each blocking request in a worker thread.  When constructed with
    an empty API key the provider is disabled: no client is created and the
    async methods return empty results without calling the API.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str | None = None,
        model: str = "embedding-3",
    ) -> None:
        """Create the provider.

        Args:
            api_key: Zhipu API key.  A falsy value disables the provider
                (no client is created).
            base_url: Optional endpoint override; trailing slashes are
                stripped before it is handed to the client.  A falsy value
                means the client is built without a ``base_url`` argument.
            model: Embedding model name sent with every request.
        """
        self._model = model
        if not api_key:
            self._client = None
        elif base_url:
            self._client = ZhipuAiClient(
                api_key=api_key,
                base_url=base_url.rstrip("/"),
            )
        else:
            self._client = ZhipuAiClient(api_key=api_key)

    def _create_vectors_sync(self, texts: list[str]) -> list[list[float]]:
        """Embed one batch of texts with a single blocking API call.

        Args:
            texts: The texts for one request.  This method does not split
                the list; ``embed_texts`` is what enforces the batch size.

        Returns:
            One vector (a list of floats) per item in the response, ordered
            by each item's ``index`` field.

        Raises:
            RuntimeError: If the provider has no client (empty API key).
        """
        client = self._client
        if client is None:
            # Explicit raise instead of ``assert`` so the guard is still in
            # place when Python runs with ``-O`` (which strips asserts).
            raise RuntimeError(
                "ZhipuEmbeddingProvider has no client (api_key was empty)"
            )
        resp = client.embeddings.create(
            input=texts,
            model=self._model,
            dimensions=MEMORY_EMBEDDING_DIMENSION,
        )
        # Order by the reported ``index``; a missing/None index sorts as 0.
        ordered = sorted(resp.data, key=lambda d: d.index or 0)
        return [list(item.embedding) for item in ordered]

    async def embed_text(self, text: str) -> list[float]:
        """Embed a single text.

        Returns:
            The embedding vector, or ``[]`` when ``embed_texts`` produced no
            vectors (for example when the provider is disabled).
        """
        vectors = await self.embed_texts([text])
        return vectors[0] if vectors else []

    async def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed several texts, batching requests to respect the API limit.

        The input is sliced into chunks of ``_EMBED_BATCH_SIZE``; the chunks
        are sent one after another, each in a worker thread so the event
        loop is not blocked by the synchronous client.

        Returns:
            The vectors of all batches concatenated in batch order, or ``[]``
            when the provider is disabled or ``texts`` is empty.
        """
        # ``__init__`` stores either ``None`` or a client instance, so test
        # for ``None`` explicitly rather than relying on truthiness.
        if self._client is None or not texts:
            return []
        out: list[list[float]] = []
        for start in range(0, len(texts), _EMBED_BATCH_SIZE):
            batch = texts[start : start + _EMBED_BATCH_SIZE]
            out.extend(await asyncio.to_thread(self._create_vectors_sync, batch))
        return out