diff --git a/ollama/__init__.py b/ollama/__init__.py index 80b0858..c452f71 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -21,6 +21,7 @@ __all__ = [ 'ResponseError', 'generate', 'chat', + 'embed', 'embeddings', 'pull', 'push', @@ -36,6 +37,7 @@ _client = Client() generate = _client.generate chat = _client.chat +embed = _client.embed embeddings = _client.embeddings pull = _client.pull push = _client.push diff --git a/ollama/_client.py b/ollama/_client.py index e640a34..e991092 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -243,6 +243,29 @@ class Client(BaseClient): stream=stream, ) + def embed( + self, + model: str = '', + input: Union[str, Sequence[AnyStr]] = '', + truncate: bool = True, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: + if not model: + raise RequestError('must provide a model') + + return self._request( + 'POST', + '/api/embed', + json={ + 'model': model, + 'input': input, + 'truncate': truncate, + 'options': options or {}, + 'keep_alive': keep_alive, + }, + ).json() + def embeddings( self, model: str = '', @@ -634,6 +657,31 @@ class AsyncClient(BaseClient): stream=stream, ) + async def embed( + self, + model: str = '', + input: Union[str, Sequence[AnyStr]] = '', + truncate: bool = True, + options: Optional[Options] = None, + keep_alive: Optional[Union[float, str]] = None, + ) -> Mapping[str, Any]: + if not model: + raise RequestError('must provide a model') + + response = await self._request( + 'POST', + '/api/embed', + json={ + 'model': model, + 'input': input, + 'truncate': truncate, + 'options': options or {}, + 'keep_alive': keep_alive, + }, + ) + + return response.json() + async def embeddings( self, model: str = '',