mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 01:25:34 +02:00
264 lines
9.5 KiB
Python
264 lines
9.5 KiB
Python
import inspect
|
|
from urllib.parse import urlparse
|
|
import asyncio
|
|
import time
|
|
|
|
import logging
|
|
|
|
import redis
|
|
|
|
from open_webui.env import (
|
|
REDIS_CLUSTER,
|
|
REDIS_SOCKET_CONNECT_TIMEOUT,
|
|
REDIS_SENTINEL_HOSTS,
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
REDIS_SENTINEL_PORT,
|
|
REDIS_URL,
|
|
REDIS_RECONNECT_DELAY,
|
|
)
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
_CONNECTION_CACHE = {}
|
|
|
|
|
|
class SentinelRedisProxy:
|
|
def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
|
|
self._sentinel = sentinel
|
|
self._service = service
|
|
self._kw = kw
|
|
self._async_mode = async_mode
|
|
|
|
def _master(self):
|
|
return self._sentinel.master_for(self._service, **self._kw)
|
|
|
|
def __getattr__(self, item):
|
|
master = self._master()
|
|
orig_attr = getattr(master, item)
|
|
|
|
if not callable(orig_attr):
|
|
return orig_attr
|
|
|
|
FACTORY_METHODS = {'pipeline', 'pubsub', 'monitor', 'client', 'transaction'}
|
|
if item in FACTORY_METHODS:
|
|
return orig_attr
|
|
|
|
if self._async_mode:
|
|
if inspect.isasyncgenfunction(orig_attr):
|
|
|
|
def _wrapped_iter(*args, **kwargs):
|
|
async def _iter():
|
|
for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
|
|
try:
|
|
method = getattr(self._master(), item)
|
|
async for value in method(*args, **kwargs):
|
|
yield value
|
|
return
|
|
except (
|
|
redis.exceptions.ConnectionError,
|
|
redis.exceptions.ReadOnlyError,
|
|
) as e:
|
|
if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
|
|
log.debug(
|
|
'Redis sentinel fail-over (%s). Retry %s/%s',
|
|
type(e).__name__,
|
|
i + 1,
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
)
|
|
if REDIS_RECONNECT_DELAY:
|
|
time.sleep(REDIS_RECONNECT_DELAY / 1000)
|
|
continue
|
|
log.error(
|
|
'Redis operation failed after %s retries: %s',
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
e,
|
|
)
|
|
raise e from e
|
|
|
|
return _iter()
|
|
|
|
return _wrapped_iter
|
|
|
|
async def _wrapped(*args, **kwargs):
|
|
for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
|
|
try:
|
|
method = getattr(self._master(), item)
|
|
result = method(*args, **kwargs)
|
|
if inspect.iscoroutine(result):
|
|
return await result
|
|
return result
|
|
except (
|
|
redis.exceptions.ConnectionError,
|
|
redis.exceptions.ReadOnlyError,
|
|
) as e:
|
|
if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
|
|
log.debug(
|
|
'Redis sentinel fail-over (%s). Retry %s/%s',
|
|
type(e).__name__,
|
|
i + 1,
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
)
|
|
if REDIS_RECONNECT_DELAY:
|
|
await asyncio.sleep(REDIS_RECONNECT_DELAY / 1000)
|
|
continue
|
|
log.error(
|
|
'Redis operation failed after %s retries: %s',
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
e,
|
|
)
|
|
raise e from e
|
|
|
|
return _wrapped
|
|
|
|
else:
|
|
|
|
def _wrapped(*args, **kwargs):
|
|
for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
|
|
try:
|
|
method = getattr(self._master(), item)
|
|
return method(*args, **kwargs)
|
|
except (
|
|
redis.exceptions.ConnectionError,
|
|
redis.exceptions.ReadOnlyError,
|
|
) as e:
|
|
if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
|
|
log.debug(
|
|
'Redis sentinel fail-over (%s). Retry %s/%s',
|
|
type(e).__name__,
|
|
i + 1,
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
)
|
|
if REDIS_RECONNECT_DELAY:
|
|
time.sleep(REDIS_RECONNECT_DELAY / 1000)
|
|
continue
|
|
log.error(
|
|
'Redis operation failed after %s retries: %s',
|
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
|
e,
|
|
)
|
|
raise e from e
|
|
|
|
return _wrapped
|
|
|
|
|
|
def parse_redis_service_url(redis_url):
|
|
parsed_url = urlparse(redis_url)
|
|
if parsed_url.scheme != 'redis' and parsed_url.scheme != 'rediss':
|
|
raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
|
|
|
|
return {
|
|
'username': parsed_url.username or None,
|
|
'password': parsed_url.password or None,
|
|
'service': parsed_url.hostname or 'mymaster',
|
|
'port': parsed_url.port or 6379,
|
|
'db': int(parsed_url.path.lstrip('/') or 0),
|
|
}
|
|
|
|
|
|
def get_redis_client(async_mode=False):
|
|
try:
|
|
return get_redis_connection(
|
|
redis_url=REDIS_URL,
|
|
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
|
|
redis_cluster=REDIS_CLUSTER,
|
|
async_mode=async_mode,
|
|
)
|
|
except Exception as e:
|
|
log.debug(f'Failed to get Redis client: {e}')
|
|
return None
|
|
|
|
|
|
def get_redis_connection(
|
|
redis_url,
|
|
redis_sentinels,
|
|
redis_cluster=False,
|
|
async_mode=False,
|
|
decode_responses=True,
|
|
):
|
|
cache_key = (
|
|
redis_url,
|
|
tuple(redis_sentinels) if redis_sentinels else (),
|
|
async_mode,
|
|
decode_responses,
|
|
)
|
|
|
|
if cache_key in _CONNECTION_CACHE:
|
|
return _CONNECTION_CACHE[cache_key]
|
|
|
|
connection = None
|
|
|
|
if async_mode:
|
|
import redis.asyncio as redis
|
|
|
|
# If using sentinel in async mode
|
|
if redis_sentinels:
|
|
redis_config = parse_redis_service_url(redis_url)
|
|
sentinel = redis.sentinel.Sentinel(
|
|
redis_sentinels,
|
|
port=redis_config['port'],
|
|
db=redis_config['db'],
|
|
username=redis_config['username'],
|
|
password=redis_config['password'],
|
|
decode_responses=decode_responses,
|
|
socket_connect_timeout=REDIS_SOCKET_CONNECT_TIMEOUT,
|
|
)
|
|
connection = SentinelRedisProxy(
|
|
sentinel,
|
|
redis_config['service'],
|
|
async_mode=async_mode,
|
|
)
|
|
elif redis_cluster:
|
|
if not redis_url:
|
|
raise ValueError('Redis URL must be provided for cluster mode.')
|
|
return redis.cluster.RedisCluster.from_url(redis_url, decode_responses=decode_responses)
|
|
elif redis_url:
|
|
connection = redis.from_url(redis_url, decode_responses=decode_responses)
|
|
else:
|
|
import redis
|
|
|
|
if redis_sentinels:
|
|
redis_config = parse_redis_service_url(redis_url)
|
|
sentinel = redis.sentinel.Sentinel(
|
|
redis_sentinels,
|
|
port=redis_config['port'],
|
|
db=redis_config['db'],
|
|
username=redis_config['username'],
|
|
password=redis_config['password'],
|
|
decode_responses=decode_responses,
|
|
socket_connect_timeout=REDIS_SOCKET_CONNECT_TIMEOUT,
|
|
)
|
|
connection = SentinelRedisProxy(
|
|
sentinel,
|
|
redis_config['service'],
|
|
async_mode=async_mode,
|
|
)
|
|
elif redis_cluster:
|
|
if not redis_url:
|
|
raise ValueError('Redis URL must be provided for cluster mode.')
|
|
return redis.cluster.RedisCluster.from_url(redis_url, decode_responses=decode_responses)
|
|
elif redis_url:
|
|
connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses)
|
|
|
|
_CONNECTION_CACHE[cache_key] = connection
|
|
return connection
|
|
|
|
|
|
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
|
|
if sentinel_hosts_env:
|
|
sentinel_hosts = sentinel_hosts_env.split(',')
|
|
sentinel_port = int(sentinel_port_env)
|
|
return [(host, sentinel_port) for host in sentinel_hosts]
|
|
return []
|
|
|
|
|
|
def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
|
|
redis_config = parse_redis_service_url(redis_url)
|
|
username = redis_config['username'] or ''
|
|
password = redis_config['password'] or ''
|
|
auth_part = ''
|
|
if username or password:
|
|
auth_part = f'{username}:{password}@'
|
|
hosts_part = ','.join(f'{host}:{sentinel_port_env}' for host in sentinel_hosts_env.split(','))
|
|
return f'redis+sentinel://{auth_part}{hosts_part}/{redis_config["db"]}/{redis_config["service"]}'
|