mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 01:25:34 +02:00
refac
This commit is contained in:
@@ -23,23 +23,21 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthSession(Base):
|
||||
__tablename__ = "oauth_session"
|
||||
__tablename__ = 'oauth_session'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
provider = Column(Text, nullable=False)
|
||||
token = Column(
|
||||
Text, nullable=False
|
||||
) # JSON with access_token, id_token, refresh_token
|
||||
token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token
|
||||
expires_at = Column(BigInteger, nullable=False)
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
# Add indexes for better performance
|
||||
__table_args__ = (
|
||||
Index("idx_oauth_session_user_id", "user_id"),
|
||||
Index("idx_oauth_session_expires_at", "expires_at"),
|
||||
Index("idx_oauth_session_user_provider", "user_id", "provider"),
|
||||
Index('idx_oauth_session_user_id', 'user_id'),
|
||||
Index('idx_oauth_session_expires_at', 'expires_at'),
|
||||
Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
|
||||
)
|
||||
|
||||
|
||||
@@ -71,7 +69,7 @@ class OAuthSessionTable:
|
||||
def __init__(self):
|
||||
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
if not self.encryption_key:
|
||||
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
|
||||
raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')
|
||||
|
||||
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
||||
if len(self.encryption_key) != 44:
|
||||
@@ -83,7 +81,7 @@ class OAuthSessionTable:
|
||||
try:
|
||||
self.fernet = Fernet(self.encryption_key)
|
||||
except Exception as e:
|
||||
log.error(f"Error initializing Fernet with provided key: {e}")
|
||||
log.error(f'Error initializing Fernet with provided key: {e}')
|
||||
raise
|
||||
|
||||
def _encrypt_token(self, token) -> str:
|
||||
@@ -93,7 +91,7 @@ class OAuthSessionTable:
|
||||
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
||||
return encrypted
|
||||
except Exception as e:
|
||||
log.error(f"Error encrypting tokens: {e}")
|
||||
log.error(f'Error encrypting tokens: {e}')
|
||||
raise
|
||||
|
||||
def _decrypt_token(self, token: str):
|
||||
@@ -102,7 +100,7 @@ class OAuthSessionTable:
|
||||
decrypted = self.fernet.decrypt(token.encode()).decode()
|
||||
return json.loads(decrypted)
|
||||
except Exception as e:
|
||||
log.error(f"Error decrypting tokens: {type(e).__name__}: {e}")
|
||||
log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
@@ -120,13 +118,13 @@ class OAuthSessionTable:
|
||||
|
||||
result = OAuthSession(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
'provider': provider,
|
||||
'token': self._encrypt_token(token),
|
||||
'expires_at': token.get('expires_at'),
|
||||
'created_at': current_time,
|
||||
'updated_at': current_time,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -141,12 +139,10 @@ class OAuthSessionTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error creating OAuth session: {e}")
|
||||
log.error(f'Error creating OAuth session: {e}')
|
||||
return None
|
||||
|
||||
def get_session_by_id(
|
||||
self, session_id: str, db: Optional[Session] = None
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -158,7 +154,7 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
log.error(f'Error getting OAuth session by ID: {e}')
|
||||
return None
|
||||
|
||||
def get_session_by_id_and_user_id(
|
||||
@@ -167,11 +163,7 @@ class OAuthSessionTable:
|
||||
"""Get OAuth session by ID and user ID"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
session = (
|
||||
db.query(OAuthSession)
|
||||
.filter_by(id=session_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first()
|
||||
if session:
|
||||
db.expunge(session)
|
||||
session.token = self._decrypt_token(session.token)
|
||||
@@ -179,7 +171,7 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
log.error(f'Error getting OAuth session by ID: {e}')
|
||||
return None
|
||||
|
||||
def get_session_by_provider_and_user_id(
|
||||
@@ -201,12 +193,10 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by provider and user ID: {e}")
|
||||
log.error(f'Error getting OAuth session by provider and user ID: {e}')
|
||||
return None
|
||||
|
||||
def get_sessions_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> List[OAuthSessionModel]:
|
||||
def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
|
||||
"""Get all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -220,7 +210,7 @@ class OAuthSessionTable:
|
||||
results.append(OAuthSessionModel.model_validate(session))
|
||||
except Exception as e:
|
||||
log.warning(
|
||||
f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}"
|
||||
f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
|
||||
)
|
||||
db.query(OAuthSession).filter_by(id=session.id).delete()
|
||||
db.commit()
|
||||
@@ -228,7 +218,7 @@ class OAuthSessionTable:
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth sessions by user ID: {e}")
|
||||
log.error(f'Error getting OAuth sessions by user ID: {e}')
|
||||
return []
|
||||
|
||||
def update_session_by_id(
|
||||
@@ -241,9 +231,9 @@ class OAuthSessionTable:
|
||||
|
||||
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||
{
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"updated_at": current_time,
|
||||
'token': self._encrypt_token(token),
|
||||
'expires_at': token.get('expires_at'),
|
||||
'updated_at': current_time,
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
@@ -256,12 +246,10 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error updating OAuth session tokens: {e}")
|
||||
log.error(f'Error updating OAuth session tokens: {e}')
|
||||
return None
|
||||
|
||||
def delete_session_by_id(
|
||||
self, session_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete an OAuth session"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -269,12 +257,10 @@ class OAuthSessionTable:
|
||||
db.commit()
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth session: {e}")
|
||||
log.error(f'Error deleting OAuth session: {e}')
|
||||
return False
|
||||
|
||||
def delete_sessions_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -282,12 +268,10 @@ class OAuthSessionTable:
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||
log.error(f'Error deleting OAuth sessions by user ID: {e}')
|
||||
return False
|
||||
|
||||
def delete_sessions_by_provider(
|
||||
self, provider: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all OAuth sessions for a provider"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -295,7 +279,7 @@ class OAuthSessionTable:
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
|
||||
log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
|
||||
return False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user