mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 01:25:34 +02:00
feat: server-side OAuth token management system
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
This commit is contained in:
247
backend/open_webui/models/oauth_sessions.py
Normal file
247
backend/open_webui/models/oauth_sessions.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import time
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, Index
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# DB MODEL
|
||||
####################
|
||||
|
||||
|
||||
class OAuthSession(Base):
|
||||
__tablename__ = "oauth_session"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
provider = Column(Text, nullable=False)
|
||||
token = Column(
|
||||
Text, nullable=False
|
||||
) # JSON with access_token, id_token, refresh_token
|
||||
expires_at = Column(BigInteger, nullable=False)
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
# Add indexes for better performance
|
||||
__table_args__ = (
|
||||
Index("idx_oauth_session_user_id", "user_id"),
|
||||
Index("idx_oauth_session_expires_at", "expires_at"),
|
||||
Index("idx_oauth_session_user_provider", "user_id", "provider"),
|
||||
)
|
||||
|
||||
|
||||
class OAuthSessionModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
token: dict
|
||||
expires_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class OAuthSessionResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
class OAuthSessionTable:
|
||||
def __init__(self):
|
||||
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
if not self.encryption_key:
|
||||
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
|
||||
|
||||
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
||||
if len(self.encryption_key) != 44:
|
||||
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
|
||||
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
|
||||
else:
|
||||
self.encryption_key = self.encryption_key.encode()
|
||||
|
||||
try:
|
||||
self.fernet = Fernet(self.encryption_key)
|
||||
except Exception as e:
|
||||
log.error(f"Error initializing Fernet with provided key: {e}")
|
||||
raise
|
||||
|
||||
def _encrypt_token(self, token) -> str:
|
||||
"""Encrypt OAuth tokens for storage"""
|
||||
try:
|
||||
token_json = json.dumps(token)
|
||||
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
||||
return encrypted
|
||||
except Exception as e:
|
||||
log.error(f"Error encrypting tokens: {e}")
|
||||
raise
|
||||
|
||||
def _decrypt_token(self, token: str):
|
||||
"""Decrypt OAuth tokens from storage"""
|
||||
try:
|
||||
decrypted = self.fernet.decrypt(token.encode()).decode()
|
||||
return json.loads(decrypted)
|
||||
except Exception as e:
|
||||
log.error(f"Error decrypting tokens: {e}")
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
token: dict,
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Create a new OAuth session"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
current_time = int(time.time())
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
result = OAuthSession(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
if result:
|
||||
result.token = token # Return decrypted token
|
||||
return OAuthSessionModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error creating OAuth session: {e}")
|
||||
return None
|
||||
|
||||
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
return OAuthSessionModel.model_validate(session)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
return None
|
||||
|
||||
def get_session_by_id_and_user_id(
|
||||
self, session_id: str, user_id: str
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID and user ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
session = (
|
||||
db.query(OAuthSession)
|
||||
.filter_by(id=session_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
return OAuthSessionModel.model_validate(session)
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
return None
|
||||
|
||||
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
|
||||
"""Get all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
|
||||
|
||||
|
||||
results = []
|
||||
for session in sessions:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
results.append(OAuthSessionModel.model_validate(session))
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth sessions by user ID: {e}")
|
||||
return []
|
||||
|
||||
def update_session_by_id(
|
||||
self, session_id: str, token: dict
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Update OAuth session tokens"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
current_time = int(time.time())
|
||||
|
||||
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||
{
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"updated_at": current_time,
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
return OAuthSessionModel.model_validate(session)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error updating OAuth session tokens: {e}")
|
||||
return None
|
||||
|
||||
def delete_session_by_id(self, session_id: str) -> bool:
|
||||
"""Delete an OAuth session"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = db.query(OAuthSession).filter_by(id=session_id).delete()
|
||||
db.commit()
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth session: {e}")
|
||||
return False
|
||||
|
||||
def delete_sessions_by_user_id(self, user_id: str) -> bool:
|
||||
"""Delete all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||
return False
|
||||
|
||||
|
||||
OAuthSessions = OAuthSessionTable()
|
||||
Reference in New Issue
Block a user