This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -23,23 +23,21 @@ log = logging.getLogger(__name__)
class OAuthSession(Base):
__tablename__ = "oauth_session"
__tablename__ = 'oauth_session'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False)
token = Column(
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance
__table_args__ = (
Index("idx_oauth_session_user_id", "user_id"),
Index("idx_oauth_session_expires_at", "expires_at"),
Index("idx_oauth_session_user_provider", "user_id", "provider"),
Index('idx_oauth_session_user_id', 'user_id'),
Index('idx_oauth_session_expires_at', 'expires_at'),
Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
)
@@ -71,7 +69,7 @@ class OAuthSessionTable:
def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44:
@@ -83,7 +81,7 @@ class OAuthSessionTable:
try:
self.fernet = Fernet(self.encryption_key)
except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}")
log.error(f'Error initializing Fernet with provided key: {e}')
raise
def _encrypt_token(self, token) -> str:
@@ -93,7 +91,7 @@ class OAuthSessionTable:
encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
log.error(f'Error encrypting tokens: {e}')
raise
def _decrypt_token(self, token: str):
@@ -102,7 +100,7 @@ class OAuthSessionTable:
decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {type(e).__name__}: {e}")
log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
raise
def create_session(
@@ -120,13 +118,13 @@ class OAuthSessionTable:
result = OAuthSession(
**{
"id": id,
"user_id": user_id,
"provider": provider,
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"created_at": current_time,
"updated_at": current_time,
'id': id,
'user_id': user_id,
'provider': provider,
'token': self._encrypt_token(token),
'expires_at': token.get('expires_at'),
'created_at': current_time,
'updated_at': current_time,
}
)
@@ -141,12 +139,10 @@ class OAuthSessionTable:
else:
return None
except Exception as e:
log.error(f"Error creating OAuth session: {e}")
log.error(f'Error creating OAuth session: {e}')
return None
def get_session_by_id(
self, session_id: str, db: Optional[Session] = None
) -> Optional[OAuthSessionModel]:
def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID"""
try:
with get_db_context(db) as db:
@@ -158,7 +154,7 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
log.error(f'Error getting OAuth session by ID: {e}')
return None
def get_session_by_id_and_user_id(
@@ -167,11 +163,7 @@ class OAuthSessionTable:
"""Get OAuth session by ID and user ID"""
try:
with get_db_context(db) as db:
session = (
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first()
if session:
db.expunge(session)
session.token = self._decrypt_token(session.token)
@@ -179,7 +171,7 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
log.error(f'Error getting OAuth session by ID: {e}')
return None
def get_session_by_provider_and_user_id(
@@ -201,12 +193,10 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
log.error(f'Error getting OAuth session by provider and user ID: {e}')
return None
def get_sessions_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> List[OAuthSessionModel]:
def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
with get_db_context(db) as db:
@@ -220,7 +210,7 @@ class OAuthSessionTable:
results.append(OAuthSessionModel.model_validate(session))
except Exception as e:
log.warning(
f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}"
f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
)
db.query(OAuthSession).filter_by(id=session.id).delete()
db.commit()
@@ -228,7 +218,7 @@ class OAuthSessionTable:
return results
except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}")
log.error(f'Error getting OAuth sessions by user ID: {e}')
return []
def update_session_by_id(
@@ -241,9 +231,9 @@ class OAuthSessionTable:
db.query(OAuthSession).filter_by(id=session_id).update(
{
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"updated_at": current_time,
'token': self._encrypt_token(token),
'expires_at': token.get('expires_at'),
'updated_at': current_time,
}
)
db.commit()
@@ -256,12 +246,10 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}")
log.error(f'Error updating OAuth session tokens: {e}')
return None
def delete_session_by_id(
self, session_id: str, db: Optional[Session] = None
) -> bool:
def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
"""Delete an OAuth session"""
try:
with get_db_context(db) as db:
@@ -269,12 +257,10 @@ class OAuthSessionTable:
db.commit()
return result > 0
except Exception as e:
log.error(f"Error deleting OAuth session: {e}")
log.error(f'Error deleting OAuth session: {e}')
return False
def delete_sessions_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
"""Delete all OAuth sessions for a user"""
try:
with get_db_context(db) as db:
@@ -282,12 +268,10 @@ class OAuthSessionTable:
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
log.error(f'Error deleting OAuth sessions by user ID: {e}')
return False
def delete_sessions_by_provider(
self, provider: str, db: Optional[Session] = None
) -> bool:
def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db_context(db) as db:
@@ -295,7 +279,7 @@ class OAuthSessionTable:
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
return False