From b0ea6d9e44d424cc01ddfecb93835d70dd8720a2 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:40:30 -0700 Subject: [PATCH] Support `api/embed` (#208) * api/embed * api/embed * api/embed * rm legacy --- ollama/__init__.py | 2 ++ ollama/_client.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/ollama/__init__.py b/ollama/__init__.py index 80b0858..c452f71 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -21,6 +21,7 @@ __all__ = [ 'ResponseError', 'generate', 'chat', + 'embed', 'embeddings', 'pull', 'push', @@ -36,6 +37,7 @@ _client = Client() generate = _client.generate chat = _client.chat +embed = _client.embed embeddings = _client.embeddings pull = _client.pull push = _client.push diff --git a/ollama/_client.py b/ollama/_client.py index e640a34..e991092 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -243,6 +243,29 @@ class Client(BaseClient): stream=stream, ) + def embed( + self, + model: str = '', + input: Union[str, Sequence[AnyStr]] = '', + truncate: bool = True, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: + if not model: + raise RequestError('must provide a model') + + return self._request( + 'POST', + '/api/embed', + json={ + 'model': model, + 'input': input, + 'truncate': truncate, + 'options': options or {}, + 'keep_alive': keep_alive, + }, + ).json() + def embeddings( self, model: str = '', @@ -634,6 +657,31 @@ class AsyncClient(BaseClient): stream=stream, ) + async def embed( + self, + model: str = '', + input: Union[str, Sequence[AnyStr]] = '', + truncate: bool = True, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: + if not model: + raise RequestError('must provide a model') + + response = await self._request( + 'POST', + '/api/embed', + json={ + 'model': model, + 'input': input, + 'truncate': truncate, + 'options': options or {}, + 'keep_alive': keep_alive, + }, + ) + + return response.json() + async def embeddings( self, model: str = '',