mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 01:25:34 +02:00
refac
This commit is contained in:
@@ -56,17 +56,15 @@ def handle_peewee_migration(DATABASE_URL):
|
||||
# db = None
|
||||
try:
|
||||
# Replace the postgresql:// with postgres:// to handle the peewee migration
|
||||
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
|
||||
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
|
||||
db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://'))
|
||||
migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations'
|
||||
router = Router(db, logger=log, migrate_dir=migrate_dir)
|
||||
router.run()
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize the database connection: {e}")
|
||||
log.warning(
|
||||
"Hint: If your database password contains special characters, you may need to URL-encode it."
|
||||
)
|
||||
log.error(f'Failed to initialize the database connection: {e}')
|
||||
log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.')
|
||||
raise
|
||||
finally:
|
||||
# Properly closing the database connection
|
||||
@@ -74,7 +72,7 @@ def handle_peewee_migration(DATABASE_URL):
|
||||
db.close()
|
||||
|
||||
# Assert if db connection has been closed
|
||||
assert db.is_closed(), "Database connection is still open."
|
||||
assert db.is_closed(), 'Database connection is still open.'
|
||||
|
||||
|
||||
if ENABLE_DB_MIGRATIONS:
|
||||
@@ -84,15 +82,13 @@ if ENABLE_DB_MIGRATIONS:
|
||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||
|
||||
# Handle SQLCipher URLs
|
||||
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||
if not database_password or database_password.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
|
||||
database_password = os.environ.get('DATABASE_PASSWORD')
|
||||
if not database_password or database_password.strip() == '':
|
||||
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||
|
||||
# Extract database path from SQLCipher URL
|
||||
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
||||
db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
|
||||
|
||||
# Create a custom creator function that uses sqlcipher3
|
||||
def create_sqlcipher_connection():
|
||||
@@ -109,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
# or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
|
||||
if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
'sqlite://',
|
||||
creator=create_sqlcipher_connection,
|
||||
pool_size=DATABASE_POOL_SIZE,
|
||||
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
||||
@@ -121,28 +117,26 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
'sqlite://',
|
||||
creator=create_sqlcipher_connection,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||
log.info('Connected to encrypted SQLite database using SQLCipher')
|
||||
|
||||
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
elif 'sqlite' in SQLALCHEMY_DATABASE_URL:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False})
|
||||
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
if DATABASE_ENABLE_SQLITE_WAL:
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute('PRAGMA journal_mode=WAL')
|
||||
else:
|
||||
cursor.execute("PRAGMA journal_mode=DELETE")
|
||||
cursor.execute('PRAGMA journal_mode=DELETE')
|
||||
cursor.close()
|
||||
|
||||
event.listen(engine, "connect", on_connect)
|
||||
event.listen(engine, 'connect', on_connect)
|
||||
else:
|
||||
if isinstance(DATABASE_POOL_SIZE, int):
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
@@ -156,16 +150,12 @@ else:
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool)
|
||||
else:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
||||
Base = declarative_base(metadata=metadata_obj)
|
||||
ScopedSession = scoped_session(SessionLocal)
|
||||
|
||||
Reference in New Issue
Block a user