Files
open-webui/backend/open_webui/utils/models.py
Timothy Jaeryang Baek de3317e26b refac
2026-03-17 17:58:01 -05:00

459 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import time
import logging
import asyncio
import sys
from aiocache import cached
from fastapi import Request
from open_webui.socket.utils import RedisDict
from open_webui.routers import openai, ollama
from open_webui.functions import get_function_models
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.access_grants import AccessGrants
from open_webui.models.groups import Groups
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.access_control import has_access
from open_webui.config import (
BYPASS_ADMIN_ACCESS_CONTROL,
DEFAULT_ARENA_MODEL,
)
from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL
from open_webui.models.users import UserModel
# Configure the root logger to write to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module-scoped logger used by the functions in this file.
log = logging.getLogger(__name__)
async def fetch_ollama_models(request: Request, user: UserModel = None):
    """Return the Ollama models reshaped into the common model-entry format.

    Awaits ``ollama.get_all_models`` and converts every item of the
    ``'models'`` list in its response into a dict. The untouched source item
    is kept under the ``'ollama'`` key.
    """
    response = await ollama.get_all_models(request, user=user)

    entries = []
    for raw in response['models']:
        entry = {
            'id': raw['model'],
            'name': raw['name'],
            'object': 'model',
            # Stamped with the time of this call, not taken from the response.
            'created': int(time.time()),
            'owned_by': 'ollama',
            'ollama': raw,
            'connection_type': raw.get('connection_type', 'local'),
            'tags': raw.get('tags', []),
        }
        entries.append(entry)
    return entries
async def fetch_openai_models(request: Request, user: UserModel = None):
    """Return the ``'data'`` list from ``openai.get_all_models``."""
    return (await openai.get_all_models(request, user=user))['data']
async def get_all_base_models(request: Request, user: UserModel = None):
    """Gather the base models from every source concurrently.

    The OpenAI and Ollama fetches are only issued when the matching
    ``ENABLE_*_API`` config flag is truthy; otherwise an awaitable that
    resolves to an empty list takes their place. Function models are always
    fetched.

    Returns:
        One list ordered as function models, then OpenAI models, then Ollama
        models.
    """
    config = request.app.state.config

    if config.ENABLE_OPENAI_API:
        openai_job = fetch_openai_models(request, user)
    else:
        # Placeholder awaitable so gather() always receives three entries.
        openai_job = asyncio.sleep(0, result=[])

    if config.ENABLE_OLLAMA_API:
        ollama_job = fetch_ollama_models(request, user)
    else:
        ollama_job = asyncio.sleep(0, result=[])

    function_job = get_function_models(request)

    from_openai, from_ollama, from_functions = await asyncio.gather(openai_job, ollama_job, function_job)
    return from_functions + from_openai + from_ollama
def _build_arena_entry(arena_model):
    """Build the model-list entry for one arena model definition."""
    return {
        'id': arena_model['id'],
        'name': arena_model['name'],
        'info': {
            'meta': arena_model['meta'],
        },
        'object': 'model',
        'created': int(time.time()),
        'owned_by': 'arena',
        'arena': True,
    }


def _function_ids_from_meta(meta):
    """Return ``(action_ids, filter_ids)`` listed in a model meta dict.

    A missing/``None`` meta, or ``actionIds``/``filterIds`` set to ``None``,
    yields empty lists instead of raising.
    """
    meta = meta or {}
    return list(meta.get('actionIds') or []), list(meta.get('filterIds') or [])


def _function_icon(function, module):
    """Pick an icon: manifest ``icon_url``, then the module's ``icon_url``/``icon``."""
    manifest = function.meta.manifest or {}
    return manifest.get('icon_url', None) or getattr(module, 'icon_url', None) or getattr(module, 'icon', None)


def _action_items_from_module(function, module):
    """Describe the action(s) an action function exposes.

    A module with an ``actions`` attribute yields one item per listed action
    (id ``<function id>.<action id>``); otherwise a single item describes the
    function itself.
    """
    fallback_icon = _function_icon(function, module)
    if hasattr(module, 'actions'):
        return [
            {
                'id': f'{function.id}.{action["id"]}',
                'name': action.get('name', f'{function.name} ({action["id"]})'),
                'description': function.meta.description,
                'icon': action.get('icon_url', fallback_icon),
            }
            for action in module.actions
        ]
    return [
        {
            'id': function.id,
            'name': function.name,
            'description': function.meta.description,
            'icon': fallback_icon,
        }
    ]


def _filter_items_from_module(function, module):
    """Describe a filter function as a single-item list."""
    return [
        {
            'id': function.id,
            'name': function.name,
            'description': function.meta.description,
            'icon': _function_icon(function, module),
            'has_user_valves': hasattr(module, 'UserValves'),
        }
    ]


def _apply_default_model_metadata(models, default_metadata):
    """Fill missing ``info.meta`` values of every model from the global defaults.

    Per-model values win: a default is only used where the model's value is
    missing or ``None``; ``capabilities`` are merged key by key with the
    model's own entries taking precedence.

    ``info`` and ``meta`` are rebuilt rather than edited in place because the
    model entries are shallow copies: those nested dicts may be the very
    objects held by the base-model cache or by the arena configuration.
    """
    for model in models:
        info = model.get('info')
        if info is None:
            model['info'] = {'meta': copy.deepcopy(default_metadata)}
            continue

        meta = dict(info.get('meta') or {})
        for key, value in default_metadata.items():
            if key == 'capabilities':
                # Merge capabilities: defaults as base, per-model overrides win
                existing = meta.get('capabilities') or {}
                meta['capabilities'] = {**(value or {}), **existing}
            elif meta.get(key) is None:
                meta[key] = copy.deepcopy(value)
        model['info'] = {**info, 'meta': meta}


async def get_all_models(request, refresh: bool = False, user: UserModel = None):
    """Build the full model list and store it on ``request.app.state.MODELS``.

    Steps, in order: load the base models (from ``app.state.BASE_MODELS`` when
    caching is enabled, already populated and ``refresh`` is false), add arena
    models, apply custom-model overrides and presets from ``Models``, apply
    ``DEFAULT_MODEL_METADATA``, and attach the enabled actions and toggleable
    filters to every model.

    Args:
        request: Request whose ``app.state`` carries the config, the caches
            and the loaded function modules.
        refresh: When true, refetch the base models even if a cache exists.
        user: Forwarded to the base-model fetchers.

    Returns:
        The list of model dicts. When there are no base models an empty list
        is returned and ``app.state.MODELS`` is left untouched.
    """
    state = request.app.state
    config = state.config

    if state.MODELS and state.BASE_MODELS and (config.ENABLE_BASE_MODELS_CACHE and not refresh):
        base_models = state.BASE_MODELS
    else:
        base_models = await get_all_base_models(request, user=user)
        state.BASE_MODELS = base_models

    # Shallow-copy each entry so the top-level keys set below (name, info,
    # actions, filters, ...) never leak into the cached base models.
    models = [model.copy() for model in base_models]

    # If there are no models, return an empty list
    if not models:
        return []

    # Add arena models; fall back to the default one when none are configured
    if config.ENABLE_EVALUATION_ARENA_MODELS:
        arena_definitions = config.EVALUATION_ARENA_MODELS or [DEFAULT_ARENA_MODEL]
        models.extend(_build_arena_entry(definition) for definition in arena_definitions)

    global_action_ids = [function.id for function in Functions.get_global_action_functions()]
    # Sets: these are probed once per (model x function) pair further down.
    enabled_action_ids = {function.id for function in Functions.get_functions_by_type('action', active_only=True)}
    global_filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    enabled_filter_ids = {function.id for function in Functions.get_functions_by_type('filter', active_only=True)}

    custom_models = Models.get_all_models()

    # Single O(1) lookup: Ollama base names first, then exact IDs (exact wins).
    base_model_lookup = {}
    for model in models:
        if model.get('owned_by') == 'ollama':
            base_model_lookup.setdefault(model['id'].split(':')[0], model)
        base_model_lookup[model['id']] = model

    existing_ids = {m['id'] for m in models}

    for custom_model in custom_models:
        if custom_model.base_model_id is None:
            # Override applied directly to a base model (shares the same ID)
            model = base_model_lookup.get(custom_model.id)
            if not model:
                continue

            if custom_model.is_active:
                info = custom_model.model_dump()
                action_ids, filter_ids = _function_ids_from_meta(info.get('meta'))
                # Remove params to avoid exposing sensitive info
                info.pop('params', None)

                model['name'] = custom_model.name
                model['info'] = info
                model['action_ids'] = action_ids
                model['filter_ids'] = filter_ids
            else:
                # An inactive override hides the base model. Two overrides can
                # resolve to the same entry (Ollama base name and exact ID),
                # so the entry may already be gone.
                try:
                    models.remove(model)
                except ValueError:
                    pass

        elif custom_model.is_active:
            if custom_model.id in existing_ids:
                continue

            base_model = base_model_lookup.get(custom_model.base_model_id)
            if base_model is None:
                base_model = base_model_lookup.get(custom_model.base_model_id.split(':')[0])

            owned_by = 'openai'
            connection_type = None
            pipe = None
            if base_model:
                owned_by = base_model.get('owned_by', 'unknown')
                if 'pipe' in base_model:
                    pipe = base_model['pipe']
                connection_type = base_model.get('connection_type', None)

            model = {
                'id': f'{custom_model.id}',
                'name': custom_model.name,
                'object': 'model',
                'created': custom_model.created_at,
                'owned_by': owned_by,
                'connection_type': connection_type,
                'preset': True,
                **({'pipe': pipe} if pipe is not None else {}),
            }

            info = custom_model.model_dump()
            # Remove params to avoid exposing sensitive info
            info.pop('params', None)
            model['info'] = info

            meta = custom_model.meta.model_dump() if custom_model.meta else {}
            model['action_ids'], model['filter_ids'] = _function_ids_from_meta(meta)

            models.append(model)

    # Batch-prefetch all needed function records to avoid N+1 queries
    all_function_ids = set()
    for model in models:
        all_function_ids.update(model.get('action_ids', []))
        all_function_ids.update(model.get('filter_ids', []))
    all_function_ids.update(global_action_ids)
    all_function_ids.update(global_filter_ids)

    functions_by_id = {f.id: f for f in Functions.get_functions_by_ids(list(all_function_ids))}

    # Pre-warm the function module cache once per unique function ID.
    # This ensures each function's DB freshness check runs exactly once,
    # not once per (model x function) pair.
    for function_id in all_function_ids:
        try:
            get_function_module_from_cache(request, function_id)
        except Exception as e:
            log.info(f'Failed to load function module for {function_id}: {e}')

    # Apply global model defaults to all models
    # Per-model overrides take precedence over global defaults
    default_metadata = getattr(config, 'DEFAULT_MODEL_METADATA', None) or {}
    if default_metadata:
        _apply_default_model_metadata(models, default_metadata)

    # Batch-fetch all function valves in one query to avoid N+1 DB hits
    # inside get_action_priority.
    all_function_valves = Functions.get_function_valves_by_ids(list(all_function_ids))

    # Priorities are stable for the duration of this call, so compute each
    # one once instead of once per model that references the action.
    action_priorities = {}

    def get_action_priority(action_id):
        """Return the ``priority`` valve of an action, or 0 when unavailable."""
        if action_id not in action_priorities:
            priority = 0
            try:
                function_module = state.FUNCTIONS.get(action_id)
                if function_module and hasattr(function_module, 'Valves'):
                    valves_db = all_function_valves.get(action_id)
                    valves = function_module.Valves(**(valves_db if valves_db else {}))
                    # A None priority would not compare with ints when sorting.
                    priority = getattr(valves, 'priority', 0) or 0
            except Exception as e:
                # Best effort: a broken Valves definition must not break the listing.
                log.debug(f'Failed to read priority for action {action_id}: {e}')
            action_priorities[action_id] = priority
        return action_priorities[action_id]

    for model in models:
        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (model-specific IDs first, then global ones).
        action_ids = [
            action_id
            for action_id in dict.fromkeys(model.pop('action_ids', []) + global_action_ids)
            if action_id in enabled_action_ids
        ]
        action_ids.sort(key=lambda aid: (get_action_priority(aid), aid))

        filter_ids = [
            filter_id
            for filter_id in dict.fromkeys(model.pop('filter_ids', []) + global_filter_ids)
            if filter_id in enabled_filter_ids
        ]

        model['actions'] = []
        for action_id in action_ids:
            action_function = functions_by_id.get(action_id)
            if action_function is None:
                log.info(f'Action not found: {action_id}')
                continue

            function_module = state.FUNCTIONS.get(action_id)
            if function_module is None:
                log.info(f'Failed to load action module: {action_id}')
                continue

            model['actions'].extend(_action_items_from_module(action_function, function_module))

        model['filters'] = []
        for filter_id in filter_ids:
            filter_function = functions_by_id.get(filter_id)
            if filter_function is None:
                log.info(f'Filter not found: {filter_id}')
                continue

            function_module = state.FUNCTIONS.get(filter_id)
            if function_module is None:
                log.info(f'Failed to load filter module: {filter_id}')
                continue

            # Only filters whose module has a truthy ``toggle`` are listed.
            if getattr(function_module, 'toggle', None):
                model['filters'].extend(_filter_items_from_module(filter_function, function_module))

    log.debug(f'get_all_models() returned {len(models)} models')

    models_dict = {model['id']: model for model in models}
    if isinstance(state.MODELS, RedisDict):
        state.MODELS.set(models_dict)
    else:
        state.MODELS = models_dict

    return models
def check_model_access(user, model, db=None):
    """Raise ``Exception('Model not found')`` unless ``user`` may read ``model``.

    Arena models are checked against the ``access_grants`` stored in their
    ``info.meta``. Any other model must exist in ``Models`` and either be
    owned by the user or be readable according to ``AccessGrants``.

    Returns ``None`` when access is allowed.
    """
    if model.get('arena'):
        grants = model.get('info', {}).get('meta', {}).get('access_grants', [])
        allowed = has_access(
            user.id,
            permission='read',
            access_grants=grants,
            db=db,
        )
        if not allowed:
            raise Exception('Model not found')
        return

    record = Models.get_model_by_id(model.get('id'), db=db)
    if not record:
        raise Exception('Model not found')

    # The owner can always read the model; no grant lookup is needed.
    if user.id == record.user_id:
        return

    granted = AccessGrants.has_access(
        user_id=user.id,
        resource_type='model',
        resource_id=record.id,
        permission='read',
        db=db,
    )
    if not granted:
        raise Exception('Model not found')
def get_filtered_models(models, user, db=None):
    """Return the subset of ``models`` that ``user`` is allowed to read.

    Access control is enforced for role ``'user'``, and for role ``'admin'``
    when ``BYPASS_ADMIN_ACCESS_CONTROL`` is false, provided
    ``BYPASS_MODEL_ACCESS_CONTROL`` is false. In every other case the input
    list is returned unchanged (the same object).

    When enforcing:
      * arena models are kept if ``has_access`` accepts the ``access_grants``
        stored in their ``info.meta``;
      * other models are kept only if they carry a truthy ``info`` and the
        user either owns them (``info['user_id']``) or holds a read grant.

    Args:
        models: List of model dicts.
        user: Object exposing ``id`` and ``role``.
        db: Optional handle forwarded to the lookup helpers.
    """
    enforce_access_control = (
        user.role == 'user' or (user.role == 'admin' and not BYPASS_ADMIN_ACCESS_CONTROL)
    ) and not BYPASS_MODEL_ACCESS_CONTROL
    if not enforce_access_control:
        return models

    # Only non-arena models that carry an info record can be granted.
    model_infos = {
        model['id']: model['info'] for model in models if not model.get('arena') and model.get('info')
    }

    user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}

    # Batch-fetch accessible resource IDs in a single query instead of N has_access calls
    accessible_model_ids = AccessGrants.get_accessible_resource_ids(
        user_id=user.id,
        resource_type='model',
        resource_ids=list(model_infos.keys()),
        permission='read',
        user_group_ids=user_group_ids,
        db=db,
    )

    filtered_models = []
    for model in models:
        if model.get('arena'):
            # Tolerate a missing or None info/meta on the arena entry.
            meta = (model.get('info') or {}).get('meta') or {}
            access_grants = meta.get('access_grants', [])
            if has_access(
                user.id,
                permission='read',
                access_grants=access_grants,
                user_group_ids=user_group_ids,
                db=db,
            ):
                filtered_models.append(model)
            continue

        model_info = model_infos.get(model['id'])
        if not model_info:
            continue

        # No admin-bypass test here: this point is only reached when the
        # bypass does not apply (see enforce_access_control above).
        if user.id == model_info.get('user_id') or model['id'] in accessible_model_ids:
            filtered_models.append(model)

    return filtered_models