remove default async client

This commit is contained in:
Michael Yang
2024-01-09 13:04:07 -08:00
parent e2be701953
commit bf56bed816
2 changed files with 23 additions and 32 deletions

View File

@@ -1,4 +1,5 @@
from ollama._client import Client, AsyncClient, Message, Options
from ollama._client import Client, AsyncClient
from ollama._types import Message, Options
__all__ = [
'Client',
@@ -16,26 +17,14 @@ __all__ = [
'show',
]
_default_client = Client()
_client = Client()
generate = _default_client.generate
chat = _default_client.chat
pull = _default_client.pull
push = _default_client.push
create = _default_client.create
delete = _default_client.delete
list = _default_client.list
copy = _default_client.copy
show = _default_client.show
_async_default_client = AsyncClient()
async_generate = _async_default_client.generate
async_chat = _async_default_client.chat
async_pull = _async_default_client.pull
async_push = _async_default_client.push
async_create = _async_default_client.create
async_delete = _async_default_client.delete
async_list = _async_default_client.list
async_copy = _async_default_client.copy
async_show = _async_default_client.show
generate = _client.generate
chat = _client.chat
pull = _client.pull
push = _client.push
create = _client.create
delete = _client.delete
list = _client.list
copy = _client.copy
show = _client.show

View File

@@ -1,3 +1,4 @@
import os
import io
import json
import httpx
@@ -19,12 +20,13 @@ from ollama._types import Message, Options
class BaseClient:
def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None:
def __init__(self, client, base_url: Optional[str] = None) -> None:
base_url = base_url or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434')
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
class Client(BaseClient):
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
def __init__(self, base_url: Optional[str] = None) -> None:
super().__init__(httpx.Client, base_url)
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
@@ -38,10 +40,10 @@ class Client(BaseClient):
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
for line in r.iter_lines():
part = json.loads(line)
if e := part.get('error'):
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
yield part
yield partial
def generate(
self,
@@ -223,7 +225,7 @@ class Client(BaseClient):
class AsyncClient(BaseClient):
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
def __init__(self, base_url: Optional[str] = None) -> None:
super().__init__(httpx.AsyncClient, base_url)
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
@@ -239,10 +241,10 @@ class AsyncClient(BaseClient):
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
async for line in r.aiter_lines():
part = json.loads(line)
if e := part.get('error'):
partial = json.loads(line)
if e := partial.get('error'):
raise Exception(e)
yield part
yield partial
return inner()