mirror of
https://github.com/kharonsec/ollama-python
synced 2026-05-10 09:02:49 +02:00
fix: type hints
This commit is contained in:
@@ -34,9 +34,6 @@ class Client(BaseClient):
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
|
||||
return self._request(method, url, **kwargs).json()
|
||||
|
||||
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
|
||||
with self._client.stream(method, url, **kwargs) as r:
|
||||
for line in r.iter_lines():
|
||||
@@ -45,6 +42,14 @@ class Client(BaseClient):
|
||||
raise Exception(e)
|
||||
yield partial
|
||||
|
||||
def _request_stream(
|
||||
self,
|
||||
*args,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json()
|
||||
|
||||
def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
@@ -61,8 +66,7 @@ class Client(BaseClient):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/generate',
|
||||
json={
|
||||
@@ -77,6 +81,7 @@ class Client(BaseClient):
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def chat(
|
||||
@@ -100,8 +105,7 @@ class Client(BaseClient):
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json={
|
||||
@@ -111,6 +115,7 @@ class Client(BaseClient):
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def pull(
|
||||
@@ -119,8 +124,7 @@ class Client(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/pull',
|
||||
json={
|
||||
@@ -128,6 +132,7 @@ class Client(BaseClient):
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def push(
|
||||
@@ -136,8 +141,7 @@ class Client(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/push',
|
||||
json={
|
||||
@@ -145,6 +149,7 @@ class Client(BaseClient):
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def create(
|
||||
@@ -161,8 +166,7 @@ class Client(BaseClient):
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
return self._request_stream(
|
||||
'POST',
|
||||
'/api/create',
|
||||
json={
|
||||
@@ -170,6 +174,7 @@ class Client(BaseClient):
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
|
||||
@@ -214,14 +219,14 @@ class Client(BaseClient):
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def list(self) -> Mapping[str, Any]:
|
||||
return self._request_json('GET', '/api/tags').get('models', [])
|
||||
return self._request('GET', '/api/tags').json().get('models', [])
|
||||
|
||||
def copy(self, source: str, target: str) -> Mapping[str, Any]:
|
||||
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def show(self, model: str) -> Mapping[str, Any]:
|
||||
return self._request_json('GET', '/api/show', json={'name': model})
|
||||
return self._request('GET', '/api/show', json={'name': model}).json()
|
||||
|
||||
|
||||
class AsyncClient(BaseClient):
|
||||
@@ -233,10 +238,6 @@ class AsyncClient(BaseClient):
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
|
||||
response = await self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
|
||||
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
|
||||
async def inner():
|
||||
async with self._client.stream(method, url, **kwargs) as r:
|
||||
@@ -248,6 +249,18 @@ class AsyncClient(BaseClient):
|
||||
|
||||
return inner()
|
||||
|
||||
async def _request_stream(
|
||||
self,
|
||||
*args,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
if stream:
|
||||
return await self._stream(*args, **kwargs)
|
||||
|
||||
response = await self._request(*args, **kwargs)
|
||||
return response.json()
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
@@ -264,8 +277,7 @@ class AsyncClient(BaseClient):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/generate',
|
||||
json={
|
||||
@@ -280,6 +292,7 @@ class AsyncClient(BaseClient):
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def chat(
|
||||
@@ -303,8 +316,7 @@ class AsyncClient(BaseClient):
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json={
|
||||
@@ -314,6 +326,7 @@ class AsyncClient(BaseClient):
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def pull(
|
||||
@@ -322,8 +335,7 @@ class AsyncClient(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/pull',
|
||||
json={
|
||||
@@ -331,6 +343,7 @@ class AsyncClient(BaseClient):
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def push(
|
||||
@@ -339,8 +352,7 @@ class AsyncClient(BaseClient):
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/push',
|
||||
json={
|
||||
@@ -348,6 +360,7 @@ class AsyncClient(BaseClient):
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def create(
|
||||
@@ -364,8 +377,7 @@ class AsyncClient(BaseClient):
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
return await self._request_stream(
|
||||
'POST',
|
||||
'/api/create',
|
||||
json={
|
||||
@@ -373,6 +385,7 @@ class AsyncClient(BaseClient):
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
|
||||
@@ -424,15 +437,16 @@ class AsyncClient(BaseClient):
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
async def list(self) -> Mapping[str, Any]:
|
||||
response = await self._request_json('GET', '/api/tags')
|
||||
return response.get('models', [])
|
||||
response = await self._request('GET', '/api/tags')
|
||||
return response.json().get('models', [])
|
||||
|
||||
async def copy(self, source: str, target: str) -> Mapping[str, Any]:
|
||||
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
async def show(self, model: str) -> Mapping[str, Any]:
|
||||
return await self._request_json('GET', '/api/show', json={'name': model})
|
||||
response = await self._request('GET', '/api/show', json={'name': model})
|
||||
return response.json()
|
||||
|
||||
|
||||
def _encode_image(image) -> str:
|
||||
|
||||
Reference in New Issue
Block a user