mirror of
https://github.com/goauthentik/authentik
synced 2026-05-09 16:42:38 +02:00
* wip Co-authored-by: Jens Langhammer <jens@goauthentik.io> Signed-off-by: Jens Langhammer <jens@goauthentik.io> Signed-off-by: Dominic R <dominic@sdko.org> * remove testing files * a * wip * pls * pls2 * a * Update authentik/providers/proxy/models.py Co-authored-by: Jens L. <jens@beryju.org> Signed-off-by: Dominic R <dominic@sdko.org> * makemigrations * pls * pls1000 * dont migrate in go Signed-off-by: Jens Langhammer <jens@goauthentik.io> * set uuid Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix more test cases Signed-off-by: Jens Langhammer <jens@goauthentik.io> * better logging Signed-off-by: Jens Langhammer <jens@goauthentik.io> * set gorm nowfunc (gorm defaults to local time) Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve test db closing Signed-off-by: Jens Langhammer <jens@goauthentik.io> * move expiration to field Signed-off-by: Jens Langhammer <jens@goauthentik.io> * dont' manually set table Signed-off-by: Jens Langhammer <jens@goauthentik.io> * refactor tests more Signed-off-by: Jens Langhammer <jens@goauthentik.io> * more refactor Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix em Signed-off-by: Jens Langhammer <jens@goauthentik.io> * postgres cleanup is done by worker Signed-off-by: Jens Langhammer <jens@goauthentik.io> * update expiry and set expiring Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io> Signed-off-by: Dominic R <dominic@sdko.org> Co-authored-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens L. <jens@beryju.org>
296 lines
8.2 KiB
Go
296 lines
8.2 KiB
Go
package application
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
|
|
"goauthentik.io/internal/config"
|
|
"goauthentik.io/internal/outpost/proxyv2/constants"
|
|
"goauthentik.io/internal/outpost/proxyv2/postgresstore"
|
|
"goauthentik.io/internal/outpost/proxyv2/types"
|
|
)
|
|
|
|
// SetupTestDB opens a GORM connection to the PostgreSQL instance described by
// config.Get().PostgreSQL and returns it for use in a single test.
//
// SQL logging is silenced. GORM's NowFunc is pinned to UTC because GORM
// otherwise uses local time for the timestamps it fills in itself.
//
// If the configuration cannot be turned into a DSN, buildDSN panics. If the
// connection cannot be opened, the test fails immediately. Callers must pass
// the returned handle to CleanupTestDB when they are done.
func SetupTestDB(t *testing.T) *gorm.DB {
	// Report require failures at the calling test, not inside this helper.
	t.Helper()

	cfg := config.Get().PostgreSQL
	dsn := buildDSN(cfg)

	gormConfig := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
		NowFunc: func() time.Time {
			return time.Now().UTC()
		},
	}

	db, err := gorm.Open(postgres.Open(dsn), gormConfig)
	require.NoError(t, err)

	return db
}
|
|
|
|
// CleanupTestDB deletes every row from authentik_providers_proxy_proxysession
// and closes the connection pool behind db.
//
// Every test in this file writes to that table, and some of them assert
// absolute row counts, so the table must be left empty afterwards. Failures
// are reported through t without aborting, so later steps still run where it
// is safe to do so.
func CleanupTestDB(t *testing.T, db *gorm.DB) {
	t.Helper()

	assert.NoError(t, db.Exec("DELETE FROM authentik_providers_proxy_proxysession").Error)

	sdb, err := db.DB()
	if !assert.NoError(t, err) {
		// sdb is nil when db.DB() fails; calling Close on it would panic.
		return
	}
	assert.NoError(t, sdb.Close())
}
|
|
|
|
func buildDSN(cfg config.PostgreSQLConfig) string {
|
|
dsn, err := postgresstore.BuildDSN(cfg)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return dsn
|
|
}
|
|
|
|
func NewTestStore(db *gorm.DB) *postgresstore.PostgresStore {
|
|
return postgresstore.NewTestStore(db)
|
|
}
|
|
|
|
// TestPostgresStore_SessionLifecycle inserts a proxy session row with a
// JSON-encoded claims payload. It then reads the row back and checks that the
// user ID and the claim values survived the round trip.
func TestPostgresStore_SessionLifecycle(t *testing.T) {
	db := SetupTestDB(t)
	defer CleanupTestDB(t, db)

	// Create the session directly in the database; a random key avoids
	// collisions with rows from other tests.
	userID := uuid.New()
	sessionKey := "test_session_" + uuid.New().String()

	sessionData := map[string]interface{}{
		constants.SessionClaims: map[string]interface{}{
			"sub":                userID.String(),
			"email":              "test@example.com",
			"preferred_username": "testuser",
			"custom_claim":       "custom_value",
			"groups":             []interface{}{"admin", "user"},
		},
	}
	sessionDataJSON, err := json.Marshal(sessionData)
	require.NoError(t, err)

	session := postgresstore.ProxySession{
		UUID:        uuid.New(),
		SessionKey:  sessionKey,
		UserID:      &userID,
		SessionData: string(sessionDataJSON),
		Expires:     time.Now().Add(time.Hour),
	}

	err = db.Create(&session).Error
	require.NoError(t, err)

	// Verify session was created. A failed query must not be mistaken for a
	// count of zero.
	var count int64
	require.NoError(t, db.Model(&postgresstore.ProxySession{}).Where("session_key = ?", sessionKey).Count(&count).Error)
	assert.Equal(t, int64(1), count)

	// Verify session data
	var retrievedSession postgresstore.ProxySession
	err = db.First(&retrievedSession, "session_key = ?", sessionKey).Error
	require.NoError(t, err)

	// Guard the dereference so a missing user ID fails cleanly instead of panicking.
	require.NotNil(t, retrievedSession.UserID)
	assert.Equal(t, userID, *retrievedSession.UserID)

	// Parse session data
	var parsedData map[string]interface{}
	err = json.Unmarshal([]byte(retrievedSession.SessionData), &parsedData)
	require.NoError(t, err)

	claims, ok := parsedData[constants.SessionClaims].(map[string]interface{})
	// Stop here if the claims are missing or mistyped; the lookups below would
	// only produce misleading nil-value failures.
	require.True(t, ok)
	assert.Equal(t, "test@example.com", claims["email"])
	assert.Equal(t, "testuser", claims["preferred_username"])
	assert.Equal(t, "custom_value", claims["custom_claim"])
}
|
|
|
|
// TestPostgresStore_LogoutSessions seeds three sessions: two for one user and
// one for another. It then checks that PostgresStore.LogoutSessions leaves
// only the session whose claims do not match the supplied predicate.
func TestPostgresStore_LogoutSessions(t *testing.T) {
	db := SetupTestDB(t)
	defer CleanupTestDB(t, db)

	// Create multiple sessions for different users
	user1 := uuid.New()
	user2 := uuid.New()

	// createSessionData encodes a minimal claims payload (sub + email). A
	// marshal failure aborts the test instead of silently storing "".
	createSessionData := func(userID uuid.UUID, email string) string {
		sessionDataJSON, err := json.Marshal(map[string]interface{}{
			constants.SessionClaims: map[string]interface{}{
				"sub":   userID.String(),
				"email": email,
			},
		})
		require.NoError(t, err)
		return string(sessionDataJSON)
	}

	sessions := []postgresstore.ProxySession{
		{
			UUID:        uuid.New(),
			SessionKey:  "session_user1_1",
			UserID:      &user1,
			SessionData: createSessionData(user1, "user1@example.com"),
			Expires:     time.Now().Add(time.Hour),
		},
		{
			UUID:        uuid.New(),
			SessionKey:  "session_user1_2",
			UserID:      &user1,
			SessionData: createSessionData(user1, "user1@example.com"),
			Expires:     time.Now().Add(time.Hour),
		},
		{
			UUID:        uuid.New(),
			SessionKey:  "session_user2_1",
			UserID:      &user2,
			SessionData: createSessionData(user2, "user2@example.com"),
			Expires:     time.Now().Add(time.Hour),
		},
	}

	// Index into the slice instead of taking the address of the loop variable.
	for i := range sessions {
		require.NoError(t, db.Create(&sessions[i]).Error)
	}

	// Verify all sessions were created. The absolute count assumes the table
	// started empty (CleanupTestDB wipes it after every test).
	var totalCount int64
	require.NoError(t, db.Model(&postgresstore.ProxySession{}).Count(&totalCount).Error)
	assert.Equal(t, int64(3), totalCount)

	// Logout user1 sessions using LogoutSessions method
	store := NewTestStore(db)
	err := store.LogoutSessions(context.Background(), func(c types.Claims) bool {
		return c.Sub == user1.String()
	})
	require.NoError(t, err)

	// Verify only user2 session remains
	var remainingCount int64
	require.NoError(t, db.Model(&postgresstore.ProxySession{}).Count(&remainingCount).Error)
	assert.Equal(t, int64(1), remainingCount)

	var remainingSession postgresstore.ProxySession
	err = db.First(&remainingSession).Error
	require.NoError(t, err)
	require.NotNil(t, remainingSession.UserID)
	assert.Equal(t, user2, *remainingSession.UserID)
}
|
|
|
|
// TestPostgresStore_SessionExpiration stores one already-expired and one
// still-valid session. It selects the expired rows by comparing Expires
// against the current time, deletes them by key, and checks that only the
// valid session remains.
func TestPostgresStore_SessionExpiration(t *testing.T) {
	db := SetupTestDB(t)
	defer CleanupTestDB(t, db)

	// Create expired and valid sessions
	expiredSession := postgresstore.ProxySession{
		UUID:        uuid.New(),
		SessionKey:  "expired_session",
		SessionData: "{}",
		Expires:     time.Now().Add(-time.Hour),
	}
	validSession := postgresstore.ProxySession{
		UUID:        uuid.New(),
		SessionKey:  "valid_session",
		SessionData: "{}",
		Expires:     time.Now().Add(time.Hour),
	}

	require.NoError(t, db.Create(&expiredSession).Error)
	require.NoError(t, db.Create(&validSession).Error)

	// Load all sessions and pick out the expired ones, emulating an expiry
	// cleanup pass.
	var sessions []postgresstore.ProxySession
	require.NoError(t, db.Find(&sessions).Error)

	var expiredKeys []string
	now := time.Now()
	for _, session := range sessions {
		if now.After(session.Expires) {
			expiredKeys = append(expiredKeys, session.SessionKey)
		}
	}
	// Check the selection itself, so a wrong filter is reported directly
	// rather than only through RowsAffected below.
	require.Equal(t, []string{"expired_session"}, expiredKeys)

	result := db.Delete(&postgresstore.ProxySession{}, "session_key IN ?", expiredKeys)
	require.NoError(t, result.Error)
	assert.Equal(t, int64(1), result.RowsAffected)

	// Verify only valid session remains
	var count int64
	require.NoError(t, db.Model(&postgresstore.ProxySession{}).Count(&count).Error)
	assert.Equal(t, int64(1), count)

	var remaining postgresstore.ProxySession
	require.NoError(t, db.First(&remaining).Error)
	assert.Equal(t, "valid_session", remaining.SessionKey)
}
|
|
|
|
// TestPostgresStore_SessionClaims stores a session whose claims mix scalar
// strings with the "groups" and "entitlements" arrays. It then decodes the
// persisted JSON and checks that every value came back intact.
func TestPostgresStore_SessionClaims(t *testing.T) {
	db := SetupTestDB(t)
	defer CleanupTestDB(t, db)

	const key = "claims_test_session"
	uid := uuid.New()

	payload, err := json.Marshal(map[string]interface{}{
		constants.SessionClaims: map[string]interface{}{
			"sub":                uid.String(),
			"email":              "test@example.com",
			"preferred_username": "testuser",
			"groups":             []interface{}{"admin", "user"},
			"entitlements":       []interface{}{"read", "write"},
			"custom_field":       "custom_value",
		},
	})
	require.NoError(t, err)

	require.NoError(t, db.Create(&postgresstore.ProxySession{
		UUID:        uuid.New(),
		SessionKey:  key,
		UserID:      &uid,
		SessionData: string(payload),
		Expires:     time.Now().Add(time.Hour),
	}).Error)

	// Read the row back by its key and decode the stored JSON.
	var stored postgresstore.ProxySession
	require.NoError(t, db.First(&stored, "session_key = ?", key).Error)
	assert.Equal(t, uid, *stored.UserID)

	var decoded map[string]interface{}
	require.NoError(t, json.Unmarshal([]byte(stored.SessionData), &decoded))

	claims, ok := decoded[constants.SessionClaims].(map[string]interface{})
	assert.True(t, ok)

	// Scalar claims, checked in a fixed order.
	for _, want := range []struct{ claim, value string }{
		{"email", "test@example.com"},
		{"preferred_username", "testuser"},
		{"custom_field", "custom_value"},
	} {
		assert.Equal(t, want.value, claims[want.claim])
	}

	// Array claims decode as []interface{}; each must contain its members.
	for _, want := range []struct {
		claim string
		items []string
	}{
		{"groups", []string{"admin", "user"}},
		{"entitlements", []string{"read", "write"}},
	} {
		values, isList := claims[want.claim].([]interface{})
		assert.True(t, isList)
		for _, item := range want.items {
			assert.Contains(t, values, item)
		}
	}
}
|