package postgresstore import ( "context" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "math/big" "net/http/httptest" "os" "path/filepath" "reflect" "runtime" "slices" "testing" "time" "github.com/google/uuid" "github.com/gorilla/sessions" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" "gorm.io/gorm/logger" "goauthentik.io/internal/config" "goauthentik.io/internal/outpost/proxyv2/constants" "goauthentik.io/internal/outpost/proxyv2/types" ) // SetupTestDB creates a test database connection for testing func SetupTestDB(t *testing.T) (*gorm.DB, *RefreshableConnPool) { cfg := config.Get().PostgreSQL t.Logf("PostgreSQL config: Host=%s Port=%s User=%s DBName=%s SSLMode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode) t.Logf("Password length: %d", len(cfg.Password)) if cfg.Password == "" { t.Logf("WARNING: Password is empty!") } gormConfig := &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), NowFunc: func() time.Time { return time.Now().UTC() }, } // Use standardized setup db, pool, err := SetupGORMWithRefreshablePool(cfg, gormConfig, 10, 100, time.Hour) require.NoError(t, err) return db, pool } // CleanupTestDB removes test sessions from the database func CleanupTestDB(t *testing.T, db *gorm.DB, pool *RefreshableConnPool) { assert.NoError(t, db.Exec("DELETE FROM authentik_providers_proxy_proxysession").Error) assert.NoError(t, pool.Close()) } func TestPostgresStore_New(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) req := httptest.NewRequest("GET", "/", nil) session, err := store.New(req, "test_session") assert.NoError(t, err) assert.True(t, session.IsNew) assert.Equal(t, "test_session", session.Name()) } func TestPostgresStore_Save(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() session, err := store.New(req, "test_session") require.NoError(t, err) // Set up session claims userID := uuid.New() claims := map[string]any{ "sub": userID.String(), "email": "test@example.com", "preferred_username": "testuser", "exp": time.Now().Add(time.Hour).Unix(), "custom_claim": "custom_value", } session.Values[constants.SessionClaims] = claims err = store.Save(req, w, session) assert.NoError(t, err) // Verify session was saved to database var savedSession ProxySession err = db.First(&savedSession, "session_key = ?", session.ID).Error assert.NoError(t, err) assert.Equal(t, userID, *savedSession.UserID) // Verify session data contains claims var sessionData map[string]any err = json.Unmarshal([]byte(savedSession.SessionData), &sessionData) assert.NoError(t, err) claimsData, ok := sessionData[constants.SessionClaims].(map[string]any) assert.True(t, ok) assert.Equal(t, "test@example.com", claimsData["email"]) assert.Equal(t, "testuser", claimsData["preferred_username"]) assert.Equal(t, "custom_value", claimsData["custom_claim"]) } func TestPostgresStore_Load(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Create a session directly in the database userID := uuid.New() sessionKey := "test_session_123" sessionData := map[string]any{ constants.SessionClaims: map[string]any{ "sub": userID.String(), "email": "test@example.com", "preferred_username": "testuser", "exp": time.Now().Add(time.Hour).Unix(), "custom_claim": "custom_value", }, } sessionDataJSON, err := json.Marshal(sessionData) require.NoError(t, err) proxySession := ProxySession{ UUID: uuid.New(), SessionKey: sessionKey, UserID: &userID, SessionData: string(sessionDataJSON), Expires: time.Now().Add(time.Hour), } err = db.Create(&proxySession).Error require.NoError(t, err) // Load the session session := sessions.NewSession(store, "test_session") session.ID = sessionKey err = store.load(context.Background(), session) assert.NoError(t, err) // Verify claims were loaded correctly claims, ok := session.Values[constants.SessionClaims].(map[string]any) assert.True(t, ok) assert.Equal(t, userID.String(), claims["sub"]) assert.Equal(t, "test@example.com", claims["email"]) assert.Equal(t, "testuser", claims["preferred_username"]) assert.Equal(t, "custom_value", claims["custom_claim"]) } func TestPostgresStore_Delete(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Create a session in the database sessionKey := "test_session_456" proxySession := ProxySession{ UUID: uuid.New(), SessionKey: sessionKey, SessionData: "{}", Expires: time.Now().Add(time.Hour), } err := db.Create(&proxySession).Error require.NoError(t, err) // Delete the session session := sessions.NewSession(store, "test_session") session.ID = sessionKey err = store.delete(context.Background(), session) assert.NoError(t, err) // Verify session was deleted var count int64 db.Model(&ProxySession{}).Where("session_key = ?", sessionKey).Count(&count) assert.Equal(t, int64(0), count) } func TestPostgresStore_LogoutSessions_ByUserID(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Create multiple sessions for different users user1 := uuid.New() user2 := uuid.New() sessions := []ProxySession{ { UUID: uuid.New(), SessionKey: "test_session_user1_1", UserID: &user1, SessionData: createSessionData(t, map[string]any{ "sub": user1.String(), "email": "user1@example.com", }), }, { UUID: uuid.New(), SessionKey: "test_session_user1_2", UserID: &user1, SessionData: createSessionData(t, map[string]any{ "sub": user1.String(), "email": "user1@example.com", }), }, { UUID: uuid.New(), SessionKey: "test_session_user2_1", UserID: &user2, SessionData: createSessionData(t, map[string]any{ "sub": user2.String(), "email": "user2@example.com", }), }, } for _, session := range sessions { err := db.Create(&session).Error require.NoError(t, err) } // Test filtering by user ID ctx := context.Background() err := store.LogoutSessions(ctx, func(c types.Claims) bool { return c.Sub == user1.String() }) assert.NoError(t, err) // Verify only user2 session remains var count int64 db.Model(&ProxySession{}).Where("session_key LIKE 'test_%'").Count(&count) assert.Equal(t, int64(1), count) var remaining ProxySession err = db.Where("session_key LIKE 'test_%'").First(&remaining).Error assert.NoError(t, err) assert.Equal(t, user2, *remaining.UserID) } func TestPostgresStore_LogoutSessions_ByEmail(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Create sessions with different emails sessions := []ProxySession{ { UUID: uuid.New(), SessionKey: "test_session_admin_1", SessionData: createSessionData(t, map[string]any{ "email": "admin@example.com", }), }, { UUID: uuid.New(), SessionKey: "test_session_admin_2", SessionData: createSessionData(t, map[string]any{ "email": "admin@example.com", }), }, { UUID: uuid.New(), SessionKey: "test_session_user_1", SessionData: createSessionData(t, map[string]any{ "email": "user@example.com", }), }, } for _, session := range sessions { err := db.Create(&session).Error require.NoError(t, err) } // Logout all admin sessions ctx := context.Background() err := store.LogoutSessions(ctx, func(c types.Claims) bool { return c.Email == "admin@example.com" }) assert.NoError(t, err) // Verify only user session remains var count int64 db.Model(&ProxySession{}).Where("session_key LIKE 'test_%'").Count(&count) assert.Equal(t, int64(1), count) var remaining ProxySession err = db.Where("session_key LIKE 'test_%'").First(&remaining).Error assert.NoError(t, err) var sessionData map[string]any err = json.Unmarshal([]byte(remaining.SessionData), &sessionData) require.NoError(t, err) claims := sessionData[constants.SessionClaims].(map[string]any) assert.Equal(t, "user@example.com", claims["email"]) } func TestPostgresStore_LogoutSessions_WithGroups(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Create sessions with different group memberships sessions := []ProxySession{ { UUID: uuid.New(), SessionKey: "test_session_admin_user", SessionData: createSessionData(t, map[string]any{ "email": "admin@example.com", "groups": []any{"admin", "user"}, }), }, { UUID: uuid.New(), SessionKey: "test_session_regular_user", SessionData: createSessionData(t, map[string]any{ "email": "user@example.com", "groups": []any{"user"}, }), }, { UUID: uuid.New(), SessionKey: "test_session_guest", SessionData: createSessionData(t, map[string]any{ "email": "guest@example.com", "groups": []any{"guest"}, }), }, } for _, session := range sessions { err := db.Create(&session).Error require.NoError(t, err) } // Logout all sessions that have "admin" group ctx := context.Background() err := store.LogoutSessions(ctx, func(c types.Claims) bool { return slices.Contains(c.Groups, "admin") }) assert.NoError(t, err) // Verify admin user session was removed var count int64 db.Model(&ProxySession{}).Where("session_key LIKE 'test_%'").Count(&count) assert.Equal(t, int64(2), count) // Verify remaining sessions don't have admin group var remainingSessions []ProxySession err = db.Where("session_key LIKE 'test_%'").Find(&remainingSessions).Error assert.NoError(t, err) for _, session := range remainingSessions { var sessionData map[string]any err := json.Unmarshal([]byte(session.SessionData), &sessionData) require.NoError(t, err) claims := sessionData[constants.SessionClaims].(map[string]any) assert.NotEqual(t, "admin@example.com", claims["email"]) } } func TestPostgresStore_LoadExpiredSession(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Create an expired session sessionKey := "test_expired_load" expiredData := map[string]any{ constants.SessionClaims: map[string]any{ "sub": "test-user", }, } expiredDataJSON, _ := json.Marshal(expiredData) proxySession := ProxySession{ UUID: uuid.New(), SessionKey: sessionKey, SessionData: string(expiredDataJSON), Expires: time.Now().Add(-time.Hour), } err := db.Create(&proxySession).Error require.NoError(t, err) // Try to load the expired session session := sessions.NewSession(store, "test_session") session.ID = sessionKey err = store.load(context.Background(), session) // Should return ErrRecordNotFound because session is expired assert.Error(t, err) assert.Equal(t, gorm.ErrRecordNotFound, err) // Verify the expired session was deleted var count int64 db.Model(&ProxySession{}).Where("session_key = ?", sessionKey).Count(&count) assert.Equal(t, int64(0), count) } func TestPostgresStore_ConcurrentSessionAccess(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) // Test concurrent access by creating separate sessions for each goroutine // This tests that the connection pool handles concurrent operations correctly const numGoroutines = 10 done := make(chan error, numGoroutines) for i := range numGoroutines { go func(id int) { // Each goroutine creates its own unique session req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() session, err := store.New(req, "test_session") if err != nil { done <- fmt.Errorf("goroutine %d failed to create session: %w", id, err) return } // Set some data session.Values["goroutine_id"] = id session.Values["timestamp"] = time.Now().Unix() // Save session err = store.Save(req, w, session) if err != nil { done <- fmt.Errorf("goroutine %d failed to save: %w", id, err) return } // Load it back session2, err := store.New(req, "test_session") if err != nil { done <- fmt.Errorf("goroutine %d failed to create session for load: %w", id, err) return } session2.ID = session.ID err = store.load(context.Background(), session2) if err != nil { done <- fmt.Errorf("goroutine %d failed to load: %w", id, err) return } done <- nil }(i) } // Wait for all goroutines to complete for range numGoroutines { err := <-done assert.NoError(t, err) } } func TestBuildDSN_Validation(t *testing.T) { tests := []struct { name string cfg config.PostgreSQLConfig expectError bool errorMsg string }{ { name: "missing host", cfg: config.PostgreSQLConfig{ Port: "5432", User: "testuser", Name: "testdb", }, expectError: true, errorMsg: "PostgreSQL host is required", }, { name: "missing user", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", Name: "testdb", }, expectError: true, errorMsg: "PostgreSQL user is required", }, { name: "missing database name", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", }, expectError: true, errorMsg: "PostgreSQL database name is required", }, { name: "invalid port (zero)", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "0", User: "testuser", Name: "testdb", }, expectError: true, errorMsg: "PostgreSQL port 0 must be positive", }, { name: "invalid port (negative)", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "-1", User: "testuser", Name: "testdb", }, expectError: true, errorMsg: "PostgreSQL port -1 must be positive", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := BuildDSN(tt.cfg) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) assert.Empty(t, result) } else { assert.NoError(t, err) } }) } } func TestBuildConnConfig(t *testing.T) { tests := []struct { name string cfg config.PostgreSQLConfig validate func(*testing.T, *pgx.ConnConfig) }{ { name: "basic configuration", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "localhost", cc.Host) assert.Equal(t, uint16(5432), cc.Port) assert.Equal(t, "testuser", cc.User) assert.Equal(t, "testdb", cc.Database) assert.Equal(t, "", cc.Password) }, }, { name: "with simple password", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: "testpass", Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "testpass", cc.Password) }, }, { name: "with password containing spaces", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: "my secure password", Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "my secure password", cc.Password) }, }, { name: "with password containing single quotes", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: "pass'word", Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "pass'word", cc.Password) }, }, { name: "with password containing backslashes", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: `pass\word`, Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, `pass\word`, cc.Password) }, }, { name: "with password containing special characters", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: `p@ss w0rd!#$%^&*()`, Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, `p@ss w0rd!#$%^&*()`, cc.Password) }, }, { name: "with password containing quotes and backslashes", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: `my'pass\word"here`, Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, `my'pass\word"here`, cc.Password) }, }, { name: "with passphrase (multiple spaces)", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: "the quick brown fox jumps over", Name: "testdb", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "the quick brown fox jumps over", cc.Password) }, }, { name: "with sslmode=disable", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", SSLMode: "disable", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Nil(t, cc.TLSConfig) }, }, { name: "with sslmode=require (no certs)", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", SSLMode: "require", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.NotNil(t, cc.TLSConfig) assert.True(t, cc.TLSConfig.InsecureSkipVerify) }, }, { name: "with custom schema", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", DefaultSchema: "custom_schema", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.NotNil(t, cc.AfterConnect) }, }, { name: "with connection options", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, 10*time.Second, cc.ConnectTimeout) assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, }, { name: "with target_session_attrs", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { // target_session_attrs should NOT be in RuntimeParams _, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") // It should set ValidateConnect instead assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set for target_session_attrs") // Verify it's the correct validator function expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(), runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name()) }, }, { name: "full configuration with special password", cfg: config.PostgreSQLConfig{ Host: "db.example.com", Port: "5433", User: "admin", Password: "my super secret password!@#", Name: "production", SSLMode: "require", DefaultSchema: "app_schema", ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "db.example.com", cc.Host) assert.Equal(t, uint16(5433), cc.Port) assert.Equal(t, "admin", cc.User) assert.Equal(t, "my super secret password!@#", cc.Password) assert.Equal(t, "production", cc.Database) assert.NotNil(t, cc.AfterConnect) assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := BuildConnConfig(tt.cfg) require.NoError(t, err) require.NotNil(t, result) tt.validate(t, result) }) } } // TestBuildConnConfig_WithSSLCertificates tests SSL certificate configuration func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { rootCertPath, clientCertPath, clientKeyPath, cleanup := generateTestCerts(t) defer cleanup() tests := []struct { name string cfg config.PostgreSQLConfig validate func(*testing.T, *pgx.ConnConfig) }{ { name: "verify-full with all certificates", cfg: config.PostgreSQLConfig{ Host: "db.example.com", Port: "5432", User: "testuser", Password: "my secure password", Name: "testdb", SSLMode: "verify-full", SSLRootCert: rootCertPath, SSLCert: clientCertPath, SSLKey: clientKeyPath, }, validate: func(t *testing.T, cc *pgx.ConnConfig) { require.NotNil(t, cc.TLSConfig) assert.False(t, cc.TLSConfig.InsecureSkipVerify) assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName) assert.NotNil(t, cc.TLSConfig.RootCAs) assert.Len(t, cc.TLSConfig.Certificates, 1) }, }, { name: "verify-ca with root cert only", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", SSLMode: "verify-ca", SSLRootCert: rootCertPath, }, validate: func(t *testing.T, cc *pgx.ConnConfig) { require.NotNil(t, cc.TLSConfig) assert.False(t, cc.TLSConfig.InsecureSkipVerify) assert.NotNil(t, cc.TLSConfig.RootCAs) assert.Empty(t, cc.TLSConfig.Certificates) }, }, { name: "require with client cert", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", SSLMode: "require", SSLCert: clientCertPath, SSLKey: clientKeyPath, }, validate: func(t *testing.T, cc *pgx.ConnConfig) { require.NotNil(t, cc.TLSConfig) assert.True(t, cc.TLSConfig.InsecureSkipVerify) assert.Len(t, cc.TLSConfig.Certificates, 1) }, }, { name: "full configuration with SSL and special password", cfg: config.PostgreSQLConfig{ Host: "db.example.com", Port: "5433", User: "admin", Password: "my super secret password!@#", Name: "production", SSLMode: "verify-full", SSLRootCert: rootCertPath, SSLCert: clientCertPath, SSLKey: clientKeyPath, DefaultSchema: "app_schema", ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "db.example.com", cc.Host) assert.Equal(t, uint16(5433), cc.Port) assert.Equal(t, "admin", cc.User) assert.Equal(t, "my super secret password!@#", cc.Password) assert.Equal(t, "production", cc.Database) require.NotNil(t, cc.TLSConfig) assert.False(t, cc.TLSConfig.InsecureSkipVerify) assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName) assert.NotNil(t, cc.TLSConfig.RootCAs) assert.Len(t, cc.TLSConfig.Certificates, 1) assert.NotNil(t, cc.AfterConnect) assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := BuildConnConfig(tt.cfg) require.NoError(t, err) require.NotNil(t, result) tt.validate(t, result) }) } } // TestBuildDSN_WithSpecialPasswords tests that BuildDSN can handle passwords with special characters // by verifying the DSN can actually be used to connect to a database func TestBuildDSN_WithSpecialPasswords(t *testing.T) { tests := []struct { name string password string }{ {"space in password", "my password"}, {"multiple spaces", "the quick brown fox"}, {"single quote", "pass'word"}, {"backslash", `pass\word`}, {"double quote", `pass"word`}, {"special chars", `p@ss!#$%^&*()`}, {"mixed special", `my'pass\word"here`}, {"unicode", "pässwörd"}, {"leading/trailing spaces", " password "}, {"tab character", "pass\tword"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Password: tt.password, Name: "testdb", } // Test that BuildDSN doesn't error dsn, err := BuildDSN(cfg) require.NoError(t, err) require.NotEmpty(t, dsn) // Test that BuildConnConfig preserves the password exactly connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) assert.Equal(t, tt.password, connConfig.Password, "Password should be preserved exactly") }) } } func TestPostgresStore_ConnectionPoolSettings(t *testing.T) { db, pool := SetupTestDB(t) defer CleanupTestDB(t, db, pool) store := NewTestStore(db, pool) sqlDB := pool.GetDB() require.NotNil(t, sqlDB) // Verify connection pool is configured stats := sqlDB.Stats() assert.GreaterOrEqual(t, stats.MaxOpenConnections, 1, "Connection pool should be configured") // Test that we can create multiple sessions concurrently // This indirectly tests connection pool handling const numConcurrentOps = 20 done := make(chan error, numConcurrentOps) for i := range numConcurrentOps { go func(id int) { req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() session, err := store.New(req, "test_session") if err != nil { done <- err return } session.Values["test"] = id err = store.Save(req, w, session) done <- err }(i) } // Collect results for i := range numConcurrentOps { err := <-done assert.NoError(t, err, "Concurrent operation %d should succeed", i) } } // TestParseConnOptions tests the base64 JSON parsing of connection options func TestParseConnOptions(t *testing.T) { tests := []struct { name string input string expected map[string]string expectError bool errorMsg string }{ { name: "simple key-value", input: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), expected: map[string]string{"target_session_attrs": "read-write"}, }, { name: "multiple options", input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)), expected: map[string]string{"connect_timeout": "10", "application_name": "authentik"}, }, { name: "numeric value as number", input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10}`)), expected: map[string]string{"connect_timeout": "10"}, }, { name: "boolean value", input: base64.StdEncoding.EncodeToString([]byte(`{"default_transaction_read_only":true}`)), expected: map[string]string{"default_transaction_read_only": "true"}, }, { name: "empty object", input: base64.StdEncoding.EncodeToString([]byte(`{}`)), expected: map[string]string{}, }, { name: "invalid base64", input: "not-valid-base64!!!", expectError: true, errorMsg: "invalid base64 encoding", }, { name: "invalid JSON", input: base64.StdEncoding.EncodeToString([]byte(`not json`)), expectError: true, errorMsg: "invalid JSON", }, { name: "JSON array instead of object", input: base64.StdEncoding.EncodeToString([]byte(`["value1", "value2"]`)), expectError: true, errorMsg: "invalid JSON", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := parseConnOptions(tt.input) if tt.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) } else { require.NoError(t, err) assert.Equal(t, tt.expected, result) } }) } } // TestApplyConnOptions tests that connection options are applied correctly to pgx.ConnConfig func TestApplyConnOptions(t *testing.T) { tests := []struct { name string opts map[string]string validate func(*testing.T, *pgx.ConnConfig) expectError bool errorMsg string }{ { name: "connect_timeout sets ConnectTimeout", opts: map[string]string{"connect_timeout": "30"}, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, 30*time.Second, cc.ConnectTimeout) }, }, { name: "target_session_attrs sets ValidateConnect", opts: map[string]string{"target_session_attrs": "read-write"}, validate: func(t *testing.T, cc *pgx.ConnConfig) { // target_session_attrs should NOT be in RuntimeParams _, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams") // It should set ValidateConnect instead assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set") expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(), runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name()) }, }, { name: "application_name goes to RuntimeParams", opts: map[string]string{"application_name": "my-app"}, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "my-app", cc.RuntimeParams["application_name"]) }, }, { name: "statement_timeout goes to RuntimeParams", opts: map[string]string{"statement_timeout": "5000"}, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "5000", cc.RuntimeParams["statement_timeout"]) }, }, { name: "unknown options go to RuntimeParams", opts: map[string]string{"custom_param": "custom_value"}, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "custom_value", cc.RuntimeParams["custom_param"]) }, }, { name: "multiple options", opts: map[string]string{ "connect_timeout": "10", "target_session_attrs": "read-write", "application_name": "authentik", }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, 10*time.Second, cc.ConnectTimeout) // target_session_attrs should NOT be in RuntimeParams _, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams") // It should set ValidateConnect instead assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set") assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, }, { name: "invalid connect_timeout", opts: map[string]string{"connect_timeout": "not-a-number"}, expectError: true, errorMsg: "invalid connect_timeout value", }, { name: "invalid target_session_attrs", opts: map[string]string{"target_session_attrs": "invalid-mode"}, expectError: true, errorMsg: "unknown target_session_attrs value", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create a base config connConfig, err := pgx.ParseConfig("") require.NoError(t, err) connConfig.RuntimeParams = make(map[string]string) err = applyConnOptions(connConfig, tt.opts) if tt.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) } else { require.NoError(t, err) tt.validate(t, connConfig) } }) } } // TestBuildConnConfig_Base64JSONConnOptions tests the full integration of base64 JSON connection options func TestBuildConnConfig_Base64JSONConnOptions(t *testing.T) { t.Run("bug report scenario - target_session_attrs", func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "authentik", Name: "authentik", ConnOptions: "eyJ0YXJnZXRfc2Vzc2lvbl9hdHRycyI6InJlYWQtd3JpdGUifQ==", } connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) // target_session_attrs should NOT be in RuntimeParams _, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") // It should set ValidateConnect instead assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set") expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(), runtime.FuncForPC(reflect.ValueOf(connConfig.ValidateConnect).Pointer()).Name()) }) t.Run("complex connection options", func(t *testing.T) { // {"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"} connOpts := base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"}`)) cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "authentik", Name: "authentik", ConnOptions: connOpts, } connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) assert.Equal(t, 10*time.Second, connConfig.ConnectTimeout) // target_session_attrs should NOT be in RuntimeParams _, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") // It should set ValidateConnect instead assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set") assert.Equal(t, "authentik-proxy", connConfig.RuntimeParams["application_name"]) }) } // Helper function to create session data JSON func createSessionData(t *testing.T, claims map[string]any) string { sessionData := map[string]any{ constants.SessionClaims: claims, } sessionDataJSON, err := json.Marshal(sessionData) require.NoError(t, err) return string(sessionDataJSON) } // generateTestCerts creates temporary SSL certificates for testing func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPath string, cleanup func()) { tmpDir := t.TempDir() // Generate CA certificate caKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) caTemplate := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ Organization: []string{"Test CA"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, BasicConstraintsValid: true, IsCA: true, } caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) require.NoError(t, err) // Write CA certificate rootCertPath = filepath.Join(tmpDir, "root.crt") rootCertFile, err := os.Create(rootCertPath) require.NoError(t, err) defer func() { if closeErr := rootCertFile.Close(); closeErr != nil { t.Logf("failed to close root cert file: %v", closeErr) } }() err = pem.Encode(rootCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCertDER}) require.NoError(t, err) // Generate client key clientKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) // Generate client certificate clientTemplate := &x509.Certificate{ SerialNumber: big.NewInt(2), Subject: pkix.Name{ Organization: []string{"Test Client"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey) require.NoError(t, err) // Write client certificate clientCertPath = filepath.Join(tmpDir, "client.crt") clientCertFile, err := os.Create(clientCertPath) require.NoError(t, err) defer func() { if closeErr := clientCertFile.Close(); closeErr != nil { t.Logf("failed to close client cert file: %v", closeErr) } }() err = pem.Encode(clientCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}) require.NoError(t, err) // Write client key clientKeyPath = filepath.Join(tmpDir, "client.key") clientKeyFile, err := os.Create(clientKeyPath) require.NoError(t, err) defer func() { if closeErr := clientKeyFile.Close(); closeErr != nil { t.Logf("failed to close client key file: %v", closeErr) } }() clientKeyBytes := x509.MarshalPKCS1PrivateKey(clientKey) err = pem.Encode(clientKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: clientKeyBytes}) require.NoError(t, err) cleanup = func() { // TempDir cleanup is automatic in Go tests } return rootCertPath, clientCertPath, clientKeyPath, cleanup } // TestBuildConnConfig_WithBase64EncodedConnOptions demonstrates that ConnOptions // should be base64-encoded JSON but is currently being parsed as key=value pairs func TestBuildConnConfig_WithBase64EncodedConnOptions(t *testing.T) { tests := []struct { name string connOptions string expected map[string]string expectError bool }{ { name: "base64 encoded JSON with single parameter", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10"}`)), expected: map[string]string{ // connect_timeout is handled specially and NOT added to RuntimeParams }, }, { name: "base64 encoded JSON with multiple parameters", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik","statement_timeout":"30000"}`)), expected: map[string]string{ // connect_timeout is handled specially and NOT added to RuntimeParams "application_name": "authentik", "statement_timeout": "30000", }, }, { name: "base64 encoded JSON with special characters in values", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik proxy v2"}`)), expected: map[string]string{ "application_name": "authentik proxy v2", }, }, { name: "base64 encoded JSON with target_session_attrs", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write","application_name":"authentik"}`)), expected: map[string]string{ "application_name": "authentik", // target_session_attrs should NOT appear in RuntimeParams }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", ConnOptions: tt.connOptions, } result, err := BuildConnConfig(cfg) if tt.expectError { assert.Error(t, err) return } require.NoError(t, err) require.NotNil(t, result) // Verify that all expected parameters are present in RuntimeParams for key, expectedValue := range tt.expected { actualValue, exists := result.RuntimeParams[key] assert.True(t, exists, "Expected runtime parameter %s to exist", key) assert.Equal(t, expectedValue, actualValue, "Runtime parameter %s should have value %s", key, expectedValue) } // Verify that connect_timeout is handled specially (sets ConnectTimeout field, not RuntimeParams) if tt.name == "base64 encoded JSON with single parameter" || tt.name == "base64 encoded JSON with multiple parameters" { _, hasConnectTimeout := result.RuntimeParams["connect_timeout"] assert.False(t, hasConnectTimeout, "connect_timeout should not appear in RuntimeParams") assert.Equal(t, 10*time.Second, result.ConnectTimeout, "connect_timeout should be set as ConnectTimeout duration") } // Verify that target_session_attrs is NOT in RuntimeParams // (it affects connection behavior, not a runtime param) _, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") }) } } // Verifies DefaultSchema is applied via AfterConnect and never via startup RuntimeParams. func TestBuildConnConfig_SearchPath_DefaultSchema(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "authentik", Name: "authentik", DefaultSchema: "default_schema", } connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) require.NotNil(t, connConfig.AfterConnect) _, hasSearchPath := connConfig.RuntimeParams["search_path"] assert.False(t, hasSearchPath, "search_path should not appear in RuntimeParams") } // Verifies ConnOptions search_path is ignored and excluded from startup RuntimeParams. func TestBuildConnConfig_SearchPath_ConnOptions(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "authentik", Name: "authentik", ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"search_path":"connopt_schema"}`)), } connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) assert.Nil(t, connConfig.AfterConnect) _, hasSearchPath := connConfig.RuntimeParams["search_path"] assert.False(t, hasSearchPath, "search_path should not appear in RuntimeParams") } // Verifies ConnOptions search_path does not override DefaultSchema and other conn options still apply. func TestBuildConnConfig_SearchPath_ConnOptionsIgnoredWhenDefaultSchemaSet(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "authentik", Name: "authentik", DefaultSchema: "default_schema", ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"search_path":"override_schema","application_name":"authentik-proxy"}`)), } connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) require.NotNil(t, connConfig.AfterConnect) assert.Equal(t, "authentik-proxy", connConfig.RuntimeParams["application_name"]) _, hasSearchPath := connConfig.RuntimeParams["search_path"] assert.False(t, hasSearchPath, "search_path should not appear in RuntimeParams") } // Verifies inherited search_path from pgx/libpq defaults is removed from startup RuntimeParams. func TestBuildConnConfig_SearchPath_InheritedServiceSetting(t *testing.T) { serviceFile := filepath.Join(t.TempDir(), "pg_service.conf") err := os.WriteFile(serviceFile, []byte("[authentik-test]\nsearch_path=service_schema\n"), 0o600) require.NoError(t, err) t.Setenv("PGSERVICE", "authentik-test") t.Setenv("PGSERVICEFILE", serviceFile) cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "authentik", Name: "authentik", } connConfig, err := BuildConnConfig(cfg) require.NoError(t, err) require.NotNil(t, connConfig.AfterConnect) _, hasSearchPath := connConfig.RuntimeParams["search_path"] assert.False(t, hasSearchPath, "search_path should not appear in RuntimeParams") } // TestBuildConnConfig_TargetSessionAttrs demonstrates how target_session_attrs // should be properly handled using pgx's ValidateConnect callback func TestBuildConnConfig_TargetSessionAttrs(t *testing.T) { tests := []struct { name string connOptions string targetSessionAttrs string expectedValidator pgconn.ValidateConnectFunc validatorDescription string }{ { name: "target_session_attrs=read-write", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), targetSessionAttrs: "read-write", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, validatorDescription: "should validate connection is read-write by checking transaction_read_only=off", }, { name: "target_session_attrs=read-only", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-only"}`)), targetSessionAttrs: "read-only", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadOnly, validatorDescription: "should validate connection is read-only by checking transaction_read_only=on", }, { name: "target_session_attrs=primary", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"primary"}`)), targetSessionAttrs: "primary", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPrimary, validatorDescription: "should validate connection is to primary by checking in_hot_standby=off", }, { name: "target_session_attrs=standby", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"standby"}`)), targetSessionAttrs: "standby", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsStandby, validatorDescription: "should validate connection is to standby by checking in_hot_standby=on", }, { name: "target_session_attrs=prefer-standby", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"prefer-standby"}`)), targetSessionAttrs: "prefer-standby", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, validatorDescription: "should prefer standby connections (affects fallback logic)", }, { name: "target_session_attrs=any (default)", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"any"}`)), targetSessionAttrs: "any", expectedValidator: nil, validatorDescription: "should not set validator as any connection is acceptable", }, { name: "no target_session_attrs", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)), targetSessionAttrs: "", expectedValidator: nil, validatorDescription: "should not set validator when target_session_attrs is not specified", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", Port: "5432", User: "testuser", Name: "testdb", ConnOptions: tt.connOptions, } result, err := BuildConnConfig(cfg) require.NoError(t, err) require.NotNil(t, result) // Verify target_session_attrs is NOT in RuntimeParams _, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") // Verify ValidateConnect callback is set to the correct standard pgx function if tt.expectedValidator != nil { require.NotNil(t, result.ValidateConnect, "ValidateConnect should be set for target_session_attrs=%s: %s", tt.targetSessionAttrs, tt.validatorDescription) // Compare function pointers using reflect to check if it's the same function actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer()) expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer()) assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(), "ValidateConnect should be set to %s for target_session_attrs=%s", expectedFuncPtr.Name(), tt.targetSessionAttrs) t.Logf("Expected validator: %s", expectedFuncPtr.Name()) t.Logf("Actual validator: %s", actualFuncPtr.Name()) } else { assert.Nil(t, result.ValidateConnect, "ValidateConnect should not be set: %s", tt.validatorDescription) } }) } } // TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts tests that when multiple // hosts are specified, fallbacks are properly configured along with the validator func TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts(t *testing.T) { tests := []struct { name string host string port string sslMode string connOptions string targetSessionAttrs string expectedValidator pgconn.ValidateConnectFunc expectedPrimaryHost string expectedPrimaryPort uint16 expectedFallbacks []*pgconn.FallbackConfig expectTLS bool validatorDescription string }{ { name: "multiple hosts with read-write", host: "db1.local,db2.local,db3.local", port: "5432", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), targetSessionAttrs: "read-write", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, expectedPrimaryHost: "db1.local", expectedPrimaryPort: 5432, expectedFallbacks: []*pgconn.FallbackConfig{ {Host: "db2.local", Port: 5432, TLSConfig: nil}, {Host: "db3.local", Port: 5432, TLSConfig: nil}, }, expectTLS: false, validatorDescription: "should set validator and create fallbacks for additional hosts", }, { name: "multiple hosts with ports specified", host: "db1.local,db2.local,db3.local", port: "5432,5433,5434", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), targetSessionAttrs: "read-write", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, expectedPrimaryHost: "db1.local", expectedPrimaryPort: 5432, expectedFallbacks: []*pgconn.FallbackConfig{ {Host: "db2.local", Port: 5433, TLSConfig: nil}, {Host: "db3.local", Port: 5434, TLSConfig: nil}, }, expectTLS: false, validatorDescription: "should handle hosts with explicit ports", }, { name: "multiple hosts with TLS required", host: "db1.local,db2.local,db3.local", port: "5432", sslMode: "require", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"require"}`)), targetSessionAttrs: "read-write", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, expectedPrimaryHost: "db1.local", expectedPrimaryPort: 5432, expectedFallbacks: []*pgconn.FallbackConfig{ {Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil) {Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil) }, expectTLS: true, validatorDescription: "should set TLS config for all hosts when sslmode=require", }, { name: "multiple hosts with TLS verify-full", host: "db1.local,db2.local,db3.local", port: "5432", sslMode: "require", connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"verify-full"}`)), targetSessionAttrs: "read-write", expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, expectedPrimaryHost: "db1.local", expectedPrimaryPort: 5432, expectedFallbacks: []*pgconn.FallbackConfig{ {Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil) {Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil) }, expectTLS: true, validatorDescription: "should set TLS config host name for all hosts when sslmode=verify-full", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: tt.host, Port: tt.port, User: "testuser", Name: "testdb", SSLMode: tt.sslMode, ConnOptions: tt.connOptions, } result, err := BuildConnConfig(cfg) require.NoError(t, err) require.NotNil(t, result) // Verify target_session_attrs is NOT in RuntimeParams _, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"] assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") // Verify ValidateConnect is set to the correct function require.NotNil(t, result.ValidateConnect, "ValidateConnect should be set for target_session_attrs=%s with multiple hosts", tt.targetSessionAttrs) actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer()) expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer()) assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(), "ValidateConnect should be %s for target_session_attrs=%s", expectedFuncPtr.Name(), tt.targetSessionAttrs) // Verify the primary host and port assert.Equal(t, tt.expectedPrimaryHost, result.Host, "Primary host should be %s", tt.expectedPrimaryHost) assert.Equal(t, tt.expectedPrimaryPort, result.Port, "Primary port should be %d", tt.expectedPrimaryPort) // Verify primary TLSConfig based on sslmode if tt.expectTLS { assert.NotNil(t, result.TLSConfig, "Primary connection should have TLSConfig set when sslmode=%s", tt.sslMode) } else { assert.Nil(t, result.TLSConfig, "Primary connection should not have TLSConfig when sslmode is not set") } // Verify Fallbacks are configured for the additional hosts require.Len(t, result.Fallbacks, len(tt.expectedFallbacks), "Should have %d fallback configs for the additional hosts", len(tt.expectedFallbacks)) // Verify each fallback configuration for i, expectedFb := range tt.expectedFallbacks { actualFb := result.Fallbacks[i] assert.Equal(t, expectedFb.Host, actualFb.Host, "Fallback %d host should be %s", i+1, expectedFb.Host) assert.Equal(t, expectedFb.Port, actualFb.Port, "Fallback %d port should be %d", i+1, expectedFb.Port) // Verify TLSConfig is set appropriately for fallbacks if tt.expectTLS { assert.NotNil(t, actualFb.TLSConfig, "Fallback %d should have TLSConfig set when sslmode=%s", i+1, tt.sslMode) // Verify InsecureSkipVerify for sslmode=require switch tt.sslMode { case "require": assert.True(t, actualFb.TLSConfig.InsecureSkipVerify, "Fallback %d TLSConfig should have InsecureSkipVerify=true for sslmode=require", i+1) case "verify-full": assert.False(t, actualFb.TLSConfig.InsecureSkipVerify, "Fallback %d TLSConfig should have InsecureSkipVerify=false for sslmode=verify-full", i+1) assert.Equal(t, actualFb.Host, actualFb.TLSConfig.ServerName, "Fallback %d TLSConfig ServerName should match host for sslmode=verify-full", i+1) } } else { assert.Nil(t, actualFb.TLSConfig, "Fallback %d should not have TLSConfig when sslmode is not set", i+1) } } // Log the configuration for debugging t.Logf("Primary host: %s:%d", result.Host, result.Port) t.Logf("Validator: %s", actualFuncPtr.Name()) for i, fb := range result.Fallbacks { t.Logf("Fallback %d: %s:%d", i+1, fb.Host, fb.Port) } }) } } // TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs tests that multiple hosts // create fallbacks even without target_session_attrs func TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "db1.local,db2.local,db3.local", Port: "5432", User: "testuser", Name: "testdb", } result, err := BuildConnConfig(cfg) require.NoError(t, err) require.NotNil(t, result) // Verify primary host assert.Equal(t, "db1.local", result.Host) assert.Equal(t, uint16(5432), result.Port) // Verify fallbacks are created require.Len(t, result.Fallbacks, 2, "Should have 2 fallback configs") assert.Equal(t, "db2.local", result.Fallbacks[0].Host) assert.Equal(t, uint16(5432), result.Fallbacks[0].Port) assert.Equal(t, "db3.local", result.Fallbacks[1].Host) assert.Equal(t, uint16(5432), result.Fallbacks[1].Port) // Verify no ValidateConnect is set (no target_session_attrs) assert.Nil(t, result.ValidateConnect) } // TestBuildConnConfig_CommaSeparatedPorts_EdgeCases tests edge cases and error scenarios for comma-separated ports func TestBuildConnConfig_CommaSeparatedPorts_EdgeCases(t *testing.T) { tests := []struct { name string host string port string expectError bool errorContains string expectedHost string expectedPort uint16 expectedFallbacks []*pgconn.FallbackConfig }{ { name: "invalid port in comma-separated list", host: "db1.local,db2.local", port: "5432,abc", expectError: true, errorContains: "invalid port value", }, { name: "port out of range (too high)", host: "db1.local,db2.local", port: "5432,99999", expectError: true, errorContains: "PostgreSQL port 99999 is out of valid range", }, { name: "port out of range (zero)", host: "db1.local,db2.local", port: "5432,0", expectError: true, errorContains: "PostgreSQL port 0 must be positive", }, { name: "empty port string", host: "db1.local", port: "", expectError: true, errorContains: "PostgreSQL port is required", }, { name: "port with only whitespace", host: "db1.local", port: " ", expectError: true, errorContains: "invalid port value", }, { name: "mismatched number of hosts and ports", host: "db1.local,db2.local", port: "5432", expectError: false, expectedHost: "db1.local", expectedPort: 5432, expectedFallbacks: []*pgconn.FallbackConfig{ {Host: "db2.local", Port: 5432}, }, }, { name: "extra ports than hosts", host: "db1.local", port: "5432,5433", expectError: false, expectedHost: "db1.local", expectedPort: 5432, expectedFallbacks: []*pgconn.FallbackConfig{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: tt.host, Port: tt.port, User: "testuser", Name: "testdb", } c, err := BuildConnConfig(cfg) if tt.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tt.errorContains) } else { require.NoError(t, err) require.NotNil(t, c) assert.Equal(t, tt.expectedHost, c.Host) assert.Equal(t, tt.expectedPort, c.Port) require.Len(t, c.Fallbacks, len(tt.expectedFallbacks)) for i, expectedFb := range tt.expectedFallbacks { actualFb := c.Fallbacks[i] assert.Equal(t, expectedFb.Host, actualFb.Host) assert.Equal(t, expectedFb.Port, actualFb.Port) } } }) } }