mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 01:25:34 +02:00
459 lines
16 KiB
Python
459 lines
16 KiB
Python
import copy
|
||
import time
|
||
import logging
|
||
import asyncio
|
||
import sys
|
||
|
||
from aiocache import cached
|
||
from fastapi import Request
|
||
|
||
from open_webui.socket.utils import RedisDict
|
||
from open_webui.routers import openai, ollama
|
||
from open_webui.functions import get_function_models
|
||
|
||
|
||
from open_webui.models.functions import Functions
|
||
from open_webui.models.models import Models
|
||
from open_webui.models.access_grants import AccessGrants
|
||
from open_webui.models.groups import Groups
|
||
|
||
|
||
from open_webui.utils.plugin import (
|
||
load_function_module_by_id,
|
||
get_function_module_from_cache,
|
||
)
|
||
from open_webui.utils.access_control import has_access
|
||
|
||
|
||
from open_webui.config import (
|
||
BYPASS_ADMIN_ACCESS_CONTROL,
|
||
DEFAULT_ARENA_MODEL,
|
||
)
|
||
|
||
from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL
|
||
from open_webui.models.users import UserModel
|
||
|
||
# Send log records to stdout at the globally configured level.
# NOTE(review): logging.basicConfig is a no-op when the root logger already
# has handlers, so whichever module calls it first determines the setup.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)

# Module-level logger used throughout this file.
log = logging.getLogger(__name__)
|
||
|
||
|
||
async def fetch_ollama_models(request: Request, user: UserModel = None):
    """Fetch models from the Ollama router and normalize them to model dicts.

    Every entry of ``ollama.get_all_models(...)['models']`` becomes an
    OpenAI-style model dict owned by ``'ollama'``. The raw entry is kept
    under the ``'ollama'`` key. ``connection_type`` defaults to ``'local'``
    and ``tags`` defaults to ``[]``.
    """
    listing = await ollama.get_all_models(request, user=user)

    normalized = []
    for entry in listing['models']:
        normalized.append(
            {
                'id': entry['model'],
                'name': entry['name'],
                'object': 'model',
                'created': int(time.time()),
                'owned_by': 'ollama',
                'ollama': entry,
                'connection_type': entry.get('connection_type', 'local'),
                'tags': entry.get('tags', []),
            }
        )
    return normalized
|
||
|
||
|
||
async def fetch_openai_models(request: Request, user: UserModel = None):
    """Return the model entries (the ``data`` list) of the OpenAI router's listing."""
    return (await openai.get_all_models(request, user=user))['data']
|
||
|
||
|
||
async def get_all_base_models(request: Request, user: UserModel = None):
    """Collect base models from every enabled source concurrently.

    OpenAI and Ollama are queried only when enabled in the app config. A
    disabled source is replaced by ``asyncio.sleep(0, result=[])``, so the
    gather always yields three lists. Function models are always fetched.

    Returns:
        Function models, then OpenAI models, then Ollama models, concatenated.
    """
    config = request.app.state.config

    if config.ENABLE_OPENAI_API:
        openai_source = fetch_openai_models(request, user)
    else:
        openai_source = asyncio.sleep(0, result=[])

    if config.ENABLE_OLLAMA_API:
        ollama_source = fetch_ollama_models(request, user)
    else:
        ollama_source = asyncio.sleep(0, result=[])

    results = await asyncio.gather(openai_source, ollama_source, get_function_models(request))
    openai_models, ollama_models, function_models = results

    return function_models + openai_models + ollama_models
|
||
|
||
|
||
async def get_all_models(request, refresh: bool = False, user: UserModel = None):
    """Assemble every model the app can serve and cache it on app state.

    Sources:
    - Base models (function, OpenAI, Ollama), optionally served from
      ``request.app.state.BASE_MODELS``.
    - Arena models, when evaluation arenas are enabled.
    - Custom models from the Models table. A custom model with
      ``base_model_id`` None overrides a base model in place; an inactive
      override hides it. An active custom model with a ``base_model_id`` is
      appended as a preset.

    Each model then receives ``actions`` and ``filters`` resolved from its own
    and the global action/filter function IDs. DEFAULT_MODEL_METADATA is merged
    into ``info.meta`` without overriding per-model values.

    Args:
        request: Request whose ``app.state`` holds config, caches and loaded
            function modules.
        refresh: Refetch base models even when the base-model cache is enabled.
        user: Passed through to the base-model fetchers.

    Returns:
        The list of model dicts. The same models are stored on
        ``request.app.state.MODELS`` as an ``{id: model}`` mapping (via
        ``.set`` when it is a RedisDict). When there are no base models,
        ``[]`` is returned and MODELS is left untouched.
    """
    if (
        request.app.state.MODELS
        and request.app.state.BASE_MODELS
        and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
    ):
        base_models = request.app.state.BASE_MODELS
    else:
        base_models = await get_all_base_models(request, user=user)
        request.app.state.BASE_MODELS = base_models

    # Shallow-copy each base model so top-level key assignments below do not
    # leak into the cached BASE_MODELS. Nested dicts (e.g. 'info') are still
    # shared, so any code that modifies them must copy them first.
    models = [model.copy() for model in base_models]

    # If there are no models, return an empty list
    if not models:
        return []

    # Add arena models; fall back to the built-in default arena model when
    # none are configured.
    if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
        arena_sources = request.app.state.config.EVALUATION_ARENA_MODELS
        if len(arena_sources) == 0:
            arena_sources = [DEFAULT_ARENA_MODEL]
        models = models + [
            {
                'id': arena['id'],
                'name': arena['name'],
                'info': {
                    'meta': arena['meta'],
                },
                'object': 'model',
                'created': int(time.time()),
                'owned_by': 'arena',
                'arena': True,
            }
            for arena in arena_sources
        ]

    # The "enabled" collections are only used for membership tests: use sets.
    global_action_ids = [function.id for function in Functions.get_global_action_functions()]
    enabled_action_ids = {function.id for function in Functions.get_functions_by_type('action', active_only=True)}

    global_filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    enabled_filter_ids = {function.id for function in Functions.get_functions_by_type('filter', active_only=True)}

    custom_models = Models.get_all_models()

    # Single O(1) lookup: Ollama base names first, then exact IDs (exact wins).
    base_model_lookup = {}
    for model in models:
        if model.get('owned_by') == 'ollama':
            base_model_lookup.setdefault(model['id'].split(':')[0], model)
        base_model_lookup[model['id']] = model

    existing_ids = {m['id'] for m in models}

    # Identities (id()) of base models hidden by an inactive override. Several
    # lookup keys can resolve to the same model (an exact ID and an Ollama base
    # name), so removing from `models` on every hit could raise ValueError on
    # the second removal. Collect the identities and filter once after the loop.
    hidden_models = set()

    for custom_model in custom_models:
        if custom_model.base_model_id is None:
            # Override applied directly to a base model (shares the same ID)
            model = base_model_lookup.get(custom_model.id)
            if not model:
                continue

            if not custom_model.is_active:
                hidden_models.add(id(model))
                continue

            model['name'] = custom_model.name
            info = custom_model.model_dump()
            meta = info.get('meta') or {}
            model['action_ids'] = list(meta.get('actionIds') or [])
            model['filter_ids'] = list(meta.get('filterIds') or [])
            # Remove params to avoid exposing sensitive info
            info.pop('params', None)
            model['info'] = info

        elif custom_model.is_active:
            # Preset built on another model; never shadow an existing model ID.
            if custom_model.id in existing_ids:
                continue

            owned_by = 'openai'
            connection_type = None
            pipe = None

            base_model = base_model_lookup.get(custom_model.base_model_id)
            if base_model is None:
                base_model = base_model_lookup.get(custom_model.base_model_id.split(':')[0])
            if base_model:
                owned_by = base_model.get('owned_by', 'unknown')
                if 'pipe' in base_model:
                    pipe = base_model['pipe']
                connection_type = base_model.get('connection_type', None)

            info = custom_model.model_dump()
            # Remove params to avoid exposing sensitive info
            info.pop('params', None)

            action_ids = []
            filter_ids = []
            if custom_model.meta:
                meta = custom_model.meta.model_dump()
                action_ids = list(meta.get('actionIds') or [])
                filter_ids = list(meta.get('filterIds') or [])

            models.append(
                {
                    'id': f'{custom_model.id}',
                    'name': custom_model.name,
                    'object': 'model',
                    'created': custom_model.created_at,
                    'owned_by': owned_by,
                    'connection_type': connection_type,
                    'preset': True,
                    **({'pipe': pipe} if pipe is not None else {}),
                    'info': info,
                    'action_ids': action_ids,
                    'filter_ids': filter_ids,
                }
            )

    if hidden_models:
        models = [model for model in models if id(model) not in hidden_models]

    def resolve_icon(function, module):
        # Manifest icon first, then the module's `icon_url`, then its `icon`.
        return (
            function.meta.manifest.get('icon_url', None)
            or getattr(module, 'icon_url', None)
            or getattr(module, 'icon', None)
        )

    # Process action_ids to get the actions
    def get_action_items_from_module(function, module):
        # A module may expose several sub-actions through `actions`; otherwise
        # the function itself is a single action.
        if hasattr(module, 'actions'):
            return [
                {
                    'id': f'{function.id}.{action["id"]}',
                    'name': action.get('name', f'{function.name} ({action["id"]})'),
                    'description': function.meta.description,
                    'icon': action.get('icon_url', resolve_icon(function, module)),
                }
                for action in module.actions
            ]
        return [
            {
                'id': function.id,
                'name': function.name,
                'description': function.meta.description,
                'icon': resolve_icon(function, module),
            }
        ]

    # Process filter_ids to get the filters
    def get_filter_items_from_module(function, module):
        return [
            {
                'id': function.id,
                'name': function.name,
                'description': function.meta.description,
                'icon': resolve_icon(function, module),
                'has_user_valves': hasattr(module, 'UserValves'),
            }
        ]

    # Batch-prefetch all needed function records to avoid N+1 queries
    all_function_ids = set()
    for model in models:
        all_function_ids.update(model.get('action_ids', []))
        all_function_ids.update(model.get('filter_ids', []))
    all_function_ids.update(global_action_ids)
    all_function_ids.update(global_filter_ids)

    functions_by_id = {f.id: f for f in Functions.get_functions_by_ids(list(all_function_ids))}

    # Pre-warm the function module cache once per unique function ID.
    # This ensures each function's DB freshness check runs exactly once,
    # not once per (model x function) pair.
    for function_id in all_function_ids:
        try:
            get_function_module_from_cache(request, function_id)
        except Exception as e:
            log.info(f'Failed to load function module for {function_id}: {e}')

    # Apply global model defaults to all models.
    # Per-model overrides take precedence over global defaults.
    default_metadata = getattr(request.app.state.config, 'DEFAULT_MODEL_METADATA', None) or {}

    if default_metadata:
        for model in models:
            info = model.get('info')

            if info is None:
                model['info'] = {'meta': copy.deepcopy(default_metadata)}
                continue

            # Copy `info` and `meta` before writing. They may be shared with
            # the cached BASE_MODELS or with arena config (DEFAULT_ARENA_MODEL /
            # EVALUATION_ARENA_MODELS). Writing in place would persist the
            # defaults there, and later calls would then treat them as
            # per-model values, hiding changes to DEFAULT_MODEL_METADATA.
            meta = dict(info.get('meta') or {})
            model['info'] = {**info, 'meta': meta}

            for key, value in default_metadata.items():
                if key == 'capabilities':
                    # Merge capabilities: defaults as base, per-model overrides win
                    existing = meta.get('capabilities') or {}
                    meta['capabilities'] = {**(value or {}), **existing}
                elif meta.get(key) is None:
                    meta[key] = copy.deepcopy(value)

    # Batch-fetch all function valves in one query to avoid N+1 DB hits
    # inside get_action_priority.
    all_function_valves = Functions.get_function_valves_by_ids(list(all_function_ids))

    # Priorities do not change during this call: compute each once.
    action_priorities = {}

    def get_action_priority(action_id):
        if action_id not in action_priorities:
            priority = 0
            try:
                function_module = request.app.state.FUNCTIONS.get(action_id)
                if function_module and hasattr(function_module, 'Valves'):
                    valves_db = all_function_valves.get(action_id)
                    valves = function_module.Valves(**(valves_db if valves_db else {}))
                    priority = getattr(valves, 'priority', 0)
            except Exception as e:
                # Best effort: an unreadable priority just sorts as 0.
                log.debug(f'Failed to resolve priority for action {action_id}: {e}')
            action_priorities[action_id] = priority
        return action_priorities[action_id]

    for model in models:
        # Merge per-model and global IDs. dict.fromkeys deduplicates while
        # keeping a deterministic order, unlike set(). Then keep only enabled IDs.
        action_ids = [
            action_id
            for action_id in dict.fromkeys(model.pop('action_ids', []) + global_action_ids)
            if action_id in enabled_action_ids
        ]
        action_ids.sort(key=lambda aid: (get_action_priority(aid), aid))

        filter_ids = [
            filter_id
            for filter_id in dict.fromkeys(model.pop('filter_ids', []) + global_filter_ids)
            if filter_id in enabled_filter_ids
        ]

        model['actions'] = []
        for action_id in action_ids:
            action_function = functions_by_id.get(action_id)
            if action_function is None:
                log.info(f'Action not found: {action_id}')
                continue

            function_module = request.app.state.FUNCTIONS.get(action_id)
            if function_module is None:
                log.info(f'Failed to load action module: {action_id}')
                continue
            model['actions'].extend(get_action_items_from_module(action_function, function_module))

        model['filters'] = []
        for filter_id in filter_ids:
            filter_function = functions_by_id.get(filter_id)
            if filter_function is None:
                log.info(f'Filter not found: {filter_id}')
                continue

            function_module = request.app.state.FUNCTIONS.get(filter_id)
            if function_module is None:
                log.info(f'Failed to load filter module: {filter_id}')
                continue
            # Only toggleable filters are surfaced on the model.
            if getattr(function_module, 'toggle', None):
                model['filters'].extend(get_filter_items_from_module(filter_function, function_module))

    log.debug(f'get_all_models() returned {len(models)} models')

    models_dict = {model['id']: model for model in models}
    if isinstance(request.app.state.MODELS, RedisDict):
        request.app.state.MODELS.set(models_dict)
    else:
        request.app.state.MODELS = models_dict

    return models
|
||
|
||
|
||
def check_model_access(user, model, db=None):
    """Raise unless ``user`` may read ``model``.

    Arena models are checked against the ``access_grants`` list stored in
    their ``info.meta``. Any other model must exist in the Models table, and
    the user must either own it or hold a 'read' access grant on it.

    Raises:
        Exception: 'Model not found' when the model is missing or the user
            lacks read access. The same message is used for both cases.
    """
    if model.get('arena'):
        grants = model.get('info', {}).get('meta', {}).get('access_grants', [])
        allowed = has_access(
            user.id,
            permission='read',
            access_grants=grants,
            db=db,
        )
    else:
        record = Models.get_model_by_id(model.get('id'), db=db)
        # Short-circuits exactly like the explicit checks: ownership is
        # tested before the grant lookup, and a missing record skips both.
        allowed = bool(record) and (
            user.id == record.user_id
            or AccessGrants.has_access(
                user_id=user.id,
                resource_type='model',
                resource_id=record.id,
                permission='read',
                db=db,
            )
        )

    if not allowed:
        raise Exception('Model not found')
|
||
|
||
|
||
def get_filtered_models(models, user, db=None):
    """Return the subset of ``models`` that ``user`` may read.

    When access control applies:
    - It applies to role 'user', and to 'admin' unless
      BYPASS_ADMIN_ACCESS_CONTROL is set.
    - BYPASS_MODEL_ACCESS_CONTROL disables it for everyone.
    - Arena models are kept when ``has_access`` allows their
      ``info.meta.access_grants``.
    - Other models are kept only when they carry a truthy ``info`` and the
      user either owns them (``info['user_id']``) or holds a 'read' grant on
      them.

    When access control does not apply, ``models`` itself is returned
    unchanged.
    """
    enforce = (
        user.role == 'user' or (user.role == 'admin' and not BYPASS_ADMIN_ACCESS_CONTROL)
    ) and not BYPASS_MODEL_ACCESS_CONTROL
    if not enforce:
        return models

    # Non-arena models with info; only these are candidates for grant lookup.
    model_infos = {}
    for model in models:
        if model.get('arena'):
            continue
        info = model.get('info')
        if info:
            model_infos[model['id']] = info

    user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}

    # Batch-fetch accessible resource IDs in a single query instead of N has_access calls
    accessible_model_ids = AccessGrants.get_accessible_resource_ids(
        user_id=user.id,
        resource_type='model',
        resource_ids=list(model_infos.keys()),
        permission='read',
        user_group_ids=user_group_ids,
        db=db,
    )

    filtered_models = []
    for model in models:
        if model.get('arena'):
            meta = model.get('info', {}).get('meta', {})
            access_grants = meta.get('access_grants', [])
            if has_access(
                user.id,
                permission='read',
                access_grants=access_grants,
                user_group_ids=user_group_ids,
            ):
                filtered_models.append(model)
            continue

        model_info = model_infos.get(model['id'])
        # Admins with BYPASS_ADMIN_ACCESS_CONTROL never get here (early
        # return above), so only ownership or a read grant is checked.
        if model_info and (
            user.id == model_info.get('user_id') or model['id'] in accessible_model_ids
        ):
            filtered_models.append(model)

    return filtered_models
|