package postgresstore

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/sessions"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/stdlib"
	"github.com/mitchellh/mapstructure"
	log "github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"goauthentik.io/internal/config"
	"goauthentik.io/internal/outpost/proxyv2/constants"
	"goauthentik.io/internal/outpost/proxyv2/types"
)

// PostgresStore stores gorilla sessions in PostgreSQL using GORM.
//
// It provides the gorilla/sessions store methods (Get, New, Save). The session
// ID is carried in the cookie unchanged and used as the row's session_key.
type PostgresStore struct {
	db *gorm.DB
	// Keep reference to pool for cleanup; Close closes it.
	pool *RefreshableConnPool

	// default options to use when a new session is created
	// (New copies them into every session it creates)
	options sessions.Options
	// key prefix with which the session will be stored
	// (Save prepends it to newly generated session IDs)
	keyPrefix string
	log       *log.Entry
}

// ProxySession represents the session data structure in PostgreSQL: one row
// per session in the table named by TableName.
type ProxySession struct {
	// Row primary key; save assigns uuid.New() when inserting.
	UUID uuid.UUID `gorm:"type:uuid;primaryKey;column:uuid;default:gen_random_uuid()"`
	// The gorilla session ID (the cookie value); save upserts on this column.
	SessionKey string `gorm:"column:session_key"`
	// Set by save from the "sub" claim when it parses as a UUID, otherwise NULL.
	UserID *uuid.UUID `gorm:"column:user_id"`
	// JSON object holding the session values that have string keys.
	SessionData string `gorm:"type:jsonb;column:session_data"`
	// Absolute expiry (UTC). load treats rows past it as missing and
	// CleanupExpired deletes them.
	Expires time.Time `gorm:"column:expires"`
	// Always written as true by save.
	Expiring bool `gorm:"column:expiring"`
}

// TableName specifies the table name for GORM.
func (ProxySession) TableName() string {
	return "authentik_providers_proxy_proxysession"
}

// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration.
func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { // Validate required fields if cfg.Host == "" { return nil, fmt.Errorf("PostgreSQL host is required") } if cfg.User == "" { return nil, fmt.Errorf("PostgreSQL user is required") } if cfg.Name == "" { return nil, fmt.Errorf("PostgreSQL database name is required") } if cfg.Port == "" { return nil, fmt.Errorf("PostgreSQL port is required") } // Start with a default config connConfig, err := pgx.ParseConfig("") if err != nil { return nil, fmt.Errorf("failed to create default config: %w", err) } // Parse comma-separated hosts and create fallbacks // cfg.Host can be a comma-separated list like "host1,host2,host3" hosts := strings.Split(cfg.Host, ",") for i, host := range hosts { hosts[i] = strings.TrimSpace(host) } // Parse and validate comma-separated ports portStrs := strings.Split(cfg.Port, ",") ports := make([]uint16, len(portStrs)) for i, portStr := range portStrs { portStr = strings.TrimSpace(portStr) port, err := strconv.Atoi(portStr) if err != nil { return nil, fmt.Errorf("invalid port value %q: %w", portStr, err) } if port <= 0 { return nil, fmt.Errorf("PostgreSQL port %d must be positive", port) } if port > 65535 { return nil, fmt.Errorf("PostgreSQL port %d is out of valid range", port) } ports[i] = uint16(port) } // Get port for primary host primaryHost := hosts[0] primaryPort := ports[0] // Set connection parameters for primary host connConfig.Host = primaryHost connConfig.Port = primaryPort connConfig.User = cfg.User connConfig.Password = cfg.Password connConfig.Database = cfg.Name // Configure TLS/SSL if cfg.SSLMode != "" { switch cfg.SSLMode { case "disable": connConfig.TLSConfig = nil case "require", "verify-ca", "verify-full": tlsConfig := &tls.Config{} // Load root CA certificate if provided if cfg.SSLRootCert != "" { caCert, err := os.ReadFile(cfg.SSLRootCert) if err != nil { return nil, fmt.Errorf("failed to read SSL root certificate: %w", err) } caCertPool := 
x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { return nil, fmt.Errorf("failed to parse SSL root certificate") } tlsConfig.RootCAs = caCertPool } // Load client certificate and key if provided if cfg.SSLCert != "" && cfg.SSLKey != "" { cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey) if err != nil { return nil, fmt.Errorf("failed to load SSL client certificate: %w", err) } tlsConfig.Certificates = []tls.Certificate{cert} } // Set verification mode switch cfg.SSLMode { case "require": // Don't verify the server certificate (just encrypt) tlsConfig.InsecureSkipVerify = true case "verify-ca": // Verify the certificate is signed by a trusted CA tlsConfig.InsecureSkipVerify = false case "verify-full": // Verify the certificate and hostname tlsConfig.InsecureSkipVerify = false tlsConfig.ServerName = primaryHost } connConfig.TLSConfig = tlsConfig } } // Create fallback configurations for additional hosts if len(hosts) > 1 { connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1) for i, host := range hosts[1:] { port := getPortForIndex(ports, i+1) fallback := &pgconn.FallbackConfig{ Host: host, Port: port, } // Copy TLS config to fallback if present if connConfig.TLSConfig != nil { fallbackTLS := connConfig.TLSConfig.Clone() // Update ServerName for verify-full mode if cfg.SSLMode == "verify-full" { fallbackTLS.ServerName = host } fallback.TLSConfig = fallbackTLS } connConfig.Fallbacks = append(connConfig.Fallbacks, fallback) } } // Set runtime params if connConfig.RuntimeParams == nil { connConfig.RuntimeParams = make(map[string]string) } effectiveSearchPath := cfg.DefaultSchema // Parse and apply connection options if specified if cfg.ConnOptions != "" { connOpts, err := parseConnOptions(cfg.ConnOptions) if err != nil { return nil, fmt.Errorf("failed to parse connection options: %w", err) } // search_path from ConnOptions is not supported here; Django controls schema selection. 
// Always remove it so it cannot end up in startup RuntimeParams via applyConnOptions. delete(connOpts, "search_path") if err := applyConnOptions(connConfig, connOpts); err != nil { return nil, fmt.Errorf("failed to apply connection options: %w", err) } } // search_path may already be present via pgx/libpq inherited defaults (e.g. service files). // Always remove it from startup RuntimeParams; apply it via AfterConnect instead. if inheritedSearchPath, hasInheritedSearchPath := connConfig.RuntimeParams["search_path"]; hasInheritedSearchPath { if effectiveSearchPath == "" { effectiveSearchPath = inheritedSearchPath } delete(connConfig.RuntimeParams, "search_path") } // Set search_path after connection startup to avoid startup-parameter issues with PgBouncer. if effectiveSearchPath != "" { connConfig.AfterConnect = func(ctx context.Context, pgConn *pgconn.PgConn) error { result := pgConn.ExecParams( ctx, "select pg_catalog.set_config('search_path', $1, false)", [][]byte{[]byte(effectiveSearchPath)}, nil, nil, nil, ).Read() return result.Err } } return connConfig, nil } // getPortForIndex returns the port for the given host index. // If there are fewer ports than needed, returns the last port (libpq behavior). func getPortForIndex(ports []uint16, i int) uint16 { if i >= len(ports) { return ports[len(ports)-1] } return ports[i] } // parseConnOptions decodes a base64-encoded JSON string into a map of connection options. 
// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json func parseConnOptions(encoded string) (map[string]string, error) { // Base64 decode decoded, err := base64.StdEncoding.DecodeString(encoded) if err != nil { return nil, fmt.Errorf("invalid base64 encoding: %w", err) } // Parse JSON var opts map[string]any if err := json.Unmarshal(decoded, &opts); err != nil { return nil, fmt.Errorf("invalid JSON: %w", err) } // Convert all values to strings result := make(map[string]string) for k, v := range opts { switch val := v.(type) { case string: result[k] = val case float64: // JSON numbers are float64 if val == float64(int(val)) { result[k] = strconv.Itoa(int(val)) } else { result[k] = strconv.FormatFloat(val, 'f', -1, 64) } case bool: result[k] = strconv.FormatBool(val) default: result[k] = fmt.Sprintf("%v", v) } } return result, nil } // applyConnOptions applies parsed connection options to the pgx.ConnConfig. func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error { for key, value := range opts { // connect_timeout needs special handling as it's a connection-level timeout if key == "connect_timeout" { timeout, err := strconv.Atoi(value) if err != nil { return fmt.Errorf("invalid connect_timeout value: %w", err) } connConfig.ConnectTimeout = time.Duration(timeout) * time.Second continue } // target_session_attrs needs special handling to set ValidateConnect function if key == "target_session_attrs" { switch value { case "read-write": connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite case "read-only": connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly case "primary": connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary case "standby": connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby case "prefer-standby": connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby case "any": // "any" 
is the default (no validation needed) connConfig.ValidateConnect = nil default: return fmt.Errorf("unknown target_session_attrs value: %s", value) } // Do not add target_session_attrs to RuntimeParams continue } // All other options go to RuntimeParams connConfig.RuntimeParams[key] = value } return nil } // BuildDSN constructs a PostgreSQL connection string from a ConnConfig. func BuildDSN(cfg config.PostgreSQLConfig) (string, error) { connConfig, err := BuildConnConfig(cfg) if err != nil { return "", err } // Register the config and get a connection string // (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password) return stdlib.RegisterConnConfig(connConfig), nil } // SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool. // This is the standardized way to create database connections for both production and tests. // // The RefreshableConnPool wraps database/sql and automatically detects PostgreSQL // authentication errors (SQLSTATE 28xxx), refreshes credentials from config sources // (file://, env://, or plain environment variables), and reconnects without downtime. // // Parameters: // - cfg: PostgreSQL configuration (host, port, user, password, etc.) // - gormConfig: GORM configuration (logger, naming strategy, etc.) 
// - maxIdleConns: Maximum number of idle connections in the pool // - maxOpenConns: Maximum number of open connections to the database // - connMaxLifetime: Maximum lifetime of a connection // // Returns: // - *gorm.DB: GORM database instance for ORM operations // - *RefreshableConnPool: Connection pool reference (caller must Close when done) // - error: Any error encountered during setup func SetupGORMWithRefreshablePool(cfg config.PostgreSQLConfig, gormConfig *gorm.Config, maxIdleConns, maxOpenConns int, connMaxLifetime time.Duration) (*gorm.DB, *RefreshableConnPool, error) { // Build connection string dsn, err := BuildDSN(cfg) if err != nil { return nil, nil, fmt.Errorf("failed to build DSN: %w", err) } // Create refreshable connection pool pool, err := NewRefreshableConnPool(dsn, gormConfig, maxIdleConns, maxOpenConns, connMaxLifetime) if err != nil { return nil, nil, fmt.Errorf("failed to create connection pool: %w", err) } // Create GORM DB using the refreshable connection pool db, err := pool.NewGORMDB() if err != nil { _ = pool.Close() return nil, nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err) } // Test the connection with a simple query // This will trigger the connection pool's tryWithRefresh logic if there's an auth error ctx := context.Background() var result int err = db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error if err != nil { _ = pool.Close() return nil, nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err) } return db, pool, nil } // NewPostgresStore returns a new PostgresStore func NewPostgresStore(log *log.Entry) (*PostgresStore, error) { cfg := config.Get().PostgreSQL // Configure GORM gormConfig := &gorm.Config{ Logger: NewLogger(log), NowFunc: func() time.Time { return time.Now().UTC() }, } // Determine connection pool settings maxIdleConns := 4 maxOpenConns := 4 var connMaxLifetime time.Duration if cfg.ConnMaxAge > 0 { connMaxLifetime = time.Duration(cfg.ConnMaxAge) * time.Second } else { connMaxLifetime = 
time.Hour // Default 1 hour } // Use standardized setup db, pool, err := SetupGORMWithRefreshablePool(cfg, gormConfig, maxIdleConns, maxOpenConns, connMaxLifetime) if err != nil { return nil, fmt.Errorf("failed to setup database: %w", err) } ps := &PostgresStore{ db: db, pool: pool, options: sessions.Options{ Path: "/", MaxAge: 86400 * 30, // 30 days default (but overwritten in postgresstore creation based on token validation) }, keyPrefix: "authentik_proxy_session_", log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore"), } return ps, nil } // Get returns a session for the given name after adding it to the registry. func (s *PostgresStore) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(s, name) } // New returns a session for the given name without adding it to the registry. func (s *PostgresStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) opts := s.options session.Options = &opts session.IsNew = true c, err := r.Cookie(name) if err != nil { return session, nil } session.ID = c.Value err = s.load(r.Context(), session) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return session, nil } return session, err } session.IsNew = false return session, err } // Save adds a single session to the response. 
func (s *PostgresStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { // Delete if max-age is <= 0 if session.Options.MaxAge <= 0 { if err := s.delete(r.Context(), session); err != nil { return fmt.Errorf("failed to delete session: %w", err) } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } if session.ID == "" { // Generate new session ID session.ID = s.keyPrefix + generateSessionID() } if err := s.save(r.Context(), session); err != nil { return fmt.Errorf("failed to save session: %w", err) } http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options)) return nil } // Options set options to use when a new session is created func (s *PostgresStore) Options(opts sessions.Options) { s.options = opts } // KeyPrefix sets the key prefix to store session in PostgreSQL func (s *PostgresStore) KeyPrefix(keyPrefix string) { s.keyPrefix = keyPrefix } // Close closes the PostgreSQL store func (s *PostgresStore) Close() error { if s.pool != nil { return s.pool.Close() } return nil } // save writes session to PostgreSQL func (s *PostgresStore) save(ctx context.Context, session *sessions.Session) error { // Convert session.Values (map[interface{}]interface{}) to map[string]interface{} for JSON marshaling stringKeyedValues := make(map[string]any) for k, v := range session.Values { if key, ok := k.(string); ok { stringKeyedValues[key] = v } } // Serialize all session values to JSON sessionData, err := json.Marshal(stringKeyedValues) if err != nil { return fmt.Errorf("failed to marshal session values: %w", err) } // Extract user ID from claims if it exists var userID *uuid.UUID if claims, hasClaims := session.Values[constants.SessionClaims]; hasClaims { if claimsMap, ok := claims.(map[string]any); ok { if sub, exists := claimsMap["sub"]; exists { if subStr, ok := sub.(string); ok { if parsedUUID, err := uuid.Parse(subStr); err == nil { userID = &parsedUUID } } } } } proxySession := 
ProxySession{ UUID: uuid.New(), SessionKey: session.ID, UserID: userID, SessionData: string(sessionData), Expiring: true, } // Add expiration timestamp to session data if session.Options != nil && session.Options.MaxAge > 0 { expiresAt := time.Now().UTC().Add(time.Duration(session.Options.MaxAge) * time.Second) proxySession.Expires = expiresAt } return s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "session_key"}}, DoUpdates: clause.AssignmentColumns([]string{"user_id", "session_data", "expires"}), }).Create(&proxySession).Error } // load reads session from PostgreSQL func (s *PostgresStore) load(ctx context.Context, session *sessions.Session) error { var proxySession ProxySession err := s.db.WithContext(ctx).Where("session_key = ?", session.ID).First(&proxySession).Error if err != nil { return fmt.Errorf("failed to load session: %w", err) } // Check if session is expired if time.Now().UTC().After(proxySession.Expires) { // Session is expired, delete it and return not found error s.db.WithContext(ctx).Delete(&ProxySession{}, "session_key = ?", session.ID) return gorm.ErrRecordNotFound } // Deserialize session data from JSON if proxySession.SessionData != "" { // First unmarshal to map[string]interface{} var stringKeyedValues map[string]any err = json.Unmarshal([]byte(proxySession.SessionData), &stringKeyedValues) if err != nil { return fmt.Errorf("failed to unmarshal session data: %w", err) } // Convert back to map[interface{}]interface{} for gorilla/sessions compatibility session.Values = make(map[any]any) for k, v := range stringKeyedValues { session.Values[k] = v } } return nil } // delete removes session from PostgreSQL func (s *PostgresStore) delete(ctx context.Context, session *sessions.Session) error { return s.db.WithContext(ctx).Delete(&ProxySession{}, "session_key = ?", session.ID).Error } // CleanupExpired removes expired sessions by checking MaxAge in session_data func (s *PostgresStore) CleanupExpired(ctx 
context.Context) error { result := s.db.WithContext(ctx).Where(`"expires" < ?`, time.Now().UTC()).Delete(&ProxySession{}) if result.Error != nil { return fmt.Errorf("failed to delete expired sessions: %w", result.Error) } if result.RowsAffected > 0 { s.log.WithField("count", result.RowsAffected).Info("Cleaned up expired sessions") } return nil } // LogoutSessions removes sessions that match the given filter criteria // The filter function should return true for sessions that should be deleted func (s *PostgresStore) LogoutSessions(ctx context.Context, filter func(c types.Claims) bool) error { // First, try to use JSONB operators for common filter patterns to avoid N+1 queries // If the filter is too complex, fall back to client-side filtering // Pre-filter sessions using JSONB operators where possible // Only fetch sessions that have claims (session_data->'claims' IS NOT NULL) var sessions []ProxySession err := s.db.WithContext(ctx).Where(fmt.Sprintf("session_data::jsonb ? '%s'", constants.SessionClaims)).Find(&sessions).Error if err != nil { return fmt.Errorf("failed to fetch sessions: %w", err) } var sessionKeysToDelete []string for _, session := range sessions { if session.SessionData == "" { continue } var sessionData map[string]any if err := json.Unmarshal([]byte(session.SessionData), &sessionData); err != nil { continue } claimsData, hasClaims := sessionData[constants.SessionClaims] if !hasClaims { continue } claimsMap, ok := claimsData.(map[string]any) if !ok { continue } // Only decode Sub and Sid fields since those are the only ones used in filters var claims types.Claims if err := mapstructure.Decode(claimsMap, &claims); err != nil { continue } if filter(claims) { sessionKeysToDelete = append(sessionKeysToDelete, session.SessionKey) } } if len(sessionKeysToDelete) > 0 { err = s.db.WithContext(ctx).Delete(&ProxySession{}, "session_key IN ?", sessionKeysToDelete).Error if err != nil { return fmt.Errorf("failed to delete sessions: %w", err) } } return nil } 
// generateSessionID generates a random session ID func generateSessionID() string { return uuid.New().String() } // NewTestStore creates a PostgresStore for testing with the given database and pool. // The pool reference is required to properly close connections in test cleanup. func NewTestStore(db *gorm.DB, pool *RefreshableConnPool) *PostgresStore { return &PostgresStore{ db: db, pool: pool, options: sessions.Options{ Path: "/", MaxAge: 3600, }, keyPrefix: "test_session_", log: log.WithField("logger", "test"), } }