diff --git a/.dockerignore b/.dockerignore index bfdb9d241b..29cbbac9d1 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,3 +11,4 @@ blueprints/local !gen-ts-api/node_modules !gen-ts-api/dist/** !gen-go-api/ +.venv diff --git a/authentik/outposts/consumer.py b/authentik/outposts/consumer.py index 80b64999d5..58b53e24db 100644 --- a/authentik/outposts/consumer.py +++ b/authentik/outposts/consumer.py @@ -37,6 +37,9 @@ class WebsocketMessageInstruction(IntEnum): # Provider specific message PROVIDER_SPECIFIC = 3 + # Session ended + SESSION_END = 4 + @dataclass(slots=True) class WebsocketMessage: @@ -145,6 +148,14 @@ class OutpostConsumer(JsonWebsocketConsumer): asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) ) + def event_session_end(self, event): + """Event handler which is called when a session is ended""" + self.send_json( + asdict( + WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event) + ) + ) + def event_provider_specific(self, event): """Event handler which can be called by provider-specific implementations to send specific messages to the outpost""" diff --git a/authentik/outposts/signals.py b/authentik/outposts/signals.py index 73d05a4b9a..08c3ae668e 100644 --- a/authentik/outposts/signals.py +++ b/authentik/outposts/signals.py @@ -1,17 +1,24 @@ """authentik outpost signals""" +from django.contrib.auth.signals import user_logged_out from django.core.cache import cache from django.db.models import Model from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save from django.dispatch import receiver +from django.http import HttpRequest from structlog.stdlib import get_logger from authentik.brands.models import Brand -from authentik.core.models import Provider +from authentik.core.models import AuthenticatedSession, Provider, User from authentik.crypto.models import CertificateKeyPair from authentik.lib.utils.reflection import class_to_path from authentik.outposts.models import Outpost, OutpostServiceConnection -from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save +from authentik.outposts.tasks import ( + CACHE_KEY_OUTPOST_DOWN, + outpost_controller, + outpost_post_save, + outpost_session_end, +) LOGGER = get_logger() UPDATE_TRIGGERING_MODELS = ( @@ -73,3 +80,17 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): instance.user.delete() cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) + + +@receiver(user_logged_out) +def logout_revoke_direct(sender: type[User], request: HttpRequest, **_): + """Catch logout by direct logout and forward to providers""" + if not request.session or not request.session.session_key: + return + outpost_session_end.delay(request.session.session_key) + + +@receiver(pre_delete, sender=AuthenticatedSession) +def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): + """Catch logout by expiring sessions being deleted""" + outpost_session_end.delay(instance.session.session_key) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index e09dcf769f..fe716ff455 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -1,5 +1,6 @@ """outpost tasks""" +from hashlib import sha256 from os import R_OK, access from pathlib import Path from socket import gethostname @@ -49,6 +50,11 @@ LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" +def hash_session_key(session_key: str) -> str: + """Hash the session key for sending session end signals""" + return sha256(session_key.encode("ascii")).hexdigest() + + def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: """Get a controller for the outpost, when a service connection is defined""" if not outpost.service_connection: @@ -289,3 +295,20 @@ def outpost_connection_discovery(self: SystemTask): url=unix_socket_path, ) self.set_status(TaskStatus.SUCCESSFUL, *messages) + + +@CELERY_APP.task() +def outpost_session_end(session_id: str): + """Update outpost instances connected to a single outpost""" + layer = get_channel_layer() + hashed_session_id = hash_session_key(session_id) + for outpost in Outpost.objects.all(): + LOGGER.info("Sending session end signal to outpost", outpost=outpost) + group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} + async_to_sync(layer.group_send)( + group, + { + "type": "event.session.end", + "session_id": hashed_session_id, + }, + ) diff --git a/authentik/providers/proxy/signals.py b/authentik/providers/proxy/signals.py deleted file mode 100644 index 48b9f9794d..0000000000 --- a/authentik/providers/proxy/signals.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Proxy provider signals""" - -from django.contrib.auth.signals import user_logged_out -from django.db.models.signals import pre_delete -from django.dispatch import receiver -from django.http import HttpRequest - -from authentik.core.models import AuthenticatedSession, User -from authentik.providers.proxy.tasks import proxy_on_logout - - -@receiver(user_logged_out) -def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_): - """Catch logout by direct logout and forward to proxy providers""" - if not request.session or not request.session.session_key: - return - proxy_on_logout.delay(request.session.session_key) - - -@receiver(pre_delete, sender=AuthenticatedSession) -def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): - """Catch logout by expiring sessions being deleted""" - proxy_on_logout.delay(instance.session.session_key) diff --git a/authentik/providers/proxy/tasks.py b/authentik/providers/proxy/tasks.py deleted file mode 100644 index 7051619b5e..0000000000 --- a/authentik/providers/proxy/tasks.py +++ /dev/null @@ -1,26 +0,0 @@ -"""proxy provider tasks""" - -from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer - -from authentik.outposts.consumer import OUTPOST_GROUP -from authentik.outposts.models import Outpost, OutpostType -from authentik.providers.oauth2.id_token import hash_session_key -from authentik.root.celery import CELERY_APP - - -@CELERY_APP.task() -def proxy_on_logout(session_id: str): - """Update outpost instances connected to a single outpost""" - layer = get_channel_layer() - hashed_session_id = hash_session_key(session_id) - for outpost in Outpost.objects.filter(type=OutpostType.PROXY): - group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} - async_to_sync(layer.group_send)( - group, - { - "type": "event.provider.specific", - "sub_type": "logout", - "session_id": hashed_session_id, - }, - ) diff --git a/go.mod b/go.mod index 41ca5e1fcf..283328a580 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.24.0 require ( beryju.io/ldap v0.1.0 + github.com/avast/retry-go/v4 v4.6.1 github.com/coreos/go-oidc/v3 v3.14.1 github.com/getsentry/sentry-go v0.33.0 github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 diff --git a/go.sum b/go.sum index cca3ad165a..7b2ebb4d04 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7V github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk= +github.com/avast/retry-go/v4 v4.6.1/go.mod h1:V6oF8njAwxJ5gRo1Q7Cxab24xs5NCWZBeaHHBklR8mA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= diff --git a/internal/outpost/ak/api.go b/internal/outpost/ak/api.go index e68bb1e114..85d202032c 100644 --- a/internal/outpost/ak/api.go +++ b/internal/outpost/ak/api.go @@ -13,6 +13,7 @@ import ( "syscall" "time" + "github.com/avast/retry-go/v4" "github.com/getsentry/sentry-go" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -25,8 +26,6 @@ import ( "goauthentik.io/internal/utils/web" ) -type WSHandler func(ctx context.Context, args map[string]interface{}) - const ConfigLogLevel = "log_level" // APIController main controller which connects to the authentik api via http and ws @@ -43,12 +42,11 @@ type APIController struct { reloadOffset time.Duration - wsConn *websocket.Conn - lastWsReconnect time.Time - wsIsReconnecting bool - wsBackoffMultiplier int - wsHandlers []WSHandler - refreshHandlers []func() + eventConn *websocket.Conn + lastWsReconnect time.Time + wsIsReconnecting bool + eventHandlers []EventHandler + refreshHandlers []func() instanceUUID uuid.UUID } @@ -83,20 +81,19 @@ func NewAPIController(akURL url.URL, token string) *APIController { // Because we don't know the outpost UUID, we simply do a list and pick the first // The service account this token belongs to should only have access to a single outpost - var outposts *api.PaginatedOutpostList - var err error - for { - outposts, _, err = apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute() - - if err == nil { - break - } - - log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds") - time.Sleep(time.Second * 3) - } + outposts, _ := retry.DoWithData[*api.PaginatedOutpostList]( + func() (*api.PaginatedOutpostList, error) { + outposts, _, err := apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute() + return outposts, err + }, + retry.Attempts(0), + retry.Delay(time.Second*3), + retry.OnRetry(func(attempt uint, err error) { + log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds") + }), + ) if len(outposts.Results) < 1 { - panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") + log.Panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") } outpost := outposts.Results[0] @@ -119,17 +116,16 @@ func NewAPIController(akURL url.URL, token string) *APIController { token: token, logger: log, - reloadOffset: time.Duration(rand.Intn(10)) * time.Second, - instanceUUID: uuid.New(), - Outpost: outpost, - wsHandlers: []WSHandler{}, - wsBackoffMultiplier: 1, - refreshHandlers: make([]func(), 0), + reloadOffset: time.Duration(rand.Intn(10)) * time.Second, + instanceUUID: uuid.New(), + Outpost: outpost, + eventHandlers: []EventHandler{}, + refreshHandlers: make([]func(), 0), } ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") - err = ac.initWS(akURL, outpost.Pk) + err = ac.initEvent(akURL, outpost.Pk) if err != nil { - go ac.reconnectWS() + go ac.recentEvents() } ac.configureRefreshSignal() return ac @@ -200,7 +196,7 @@ func (a *APIController) OnRefresh() error { return err } -func (a *APIController) getWebsocketPingArgs() map[string]interface{} { +func (a *APIController) getEventPingArgs() map[string]interface{} { args := map[string]interface{}{ "version": constants.VERSION, "buildHash": constants.BUILD(""), @@ -226,12 +222,12 @@ func (a *APIController) StartBackgroundTasks() error { "build": constants.BUILD(""), }).Set(1) go func() { - a.logger.Debug("Starting WS Handler...") - a.startWSHandler() + a.logger.Debug("Starting Event Handler...") + a.startEventHandler() }() go func() { - a.logger.Debug("Starting WS Health notifier...") - a.startWSHealth() + a.logger.Debug("Starting Event health notifier...") + a.startEventHealth() }() go func() { a.logger.Debug("Starting Interval updater...") diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_event.go similarity index 67% rename from internal/outpost/ak/api_ws.go rename to internal/outpost/ak/api_event.go index a45738f6c1..ba05c74614 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_event.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/avast/retry-go/v4" "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" "goauthentik.io/internal/config" @@ -30,7 +31,7 @@ func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, quer return wsUrl } -func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { +func (ac *APIController) initEvent(akURL url.URL, outpostUUID string) error { query := akURL.Query() query.Set("instance_uuid", ac.instanceUUID.String()) @@ -57,19 +58,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { return err } - ac.wsConn = ws + ac.eventConn = ws // Send hello message with our version - msg := websocketMessage{ - Instruction: WebsocketInstructionHello, - Args: ac.getWebsocketPingArgs(), + msg := Event{ + Instruction: EventKindHello, + Args: ac.getEventPingArgs(), } err = ws.WriteJSON(msg) if err != nil { - ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik") + ac.logger.WithField("logger", "authentik.outpost.events").WithError(err).Warning("Failed to hello to authentik") return err } ac.lastWsReconnect = time.Now() - ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket") + ac.logger.WithField("logger", "authentik.outpost.events").WithField("outpost", outpostUUID).Info("Successfully connected websocket") return nil } @@ -77,19 +78,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { func (ac *APIController) Shutdown() { // Cleanly close the connection by sending a close message and then // waiting (with timeout) for the server to close the connection. - err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err := ac.eventConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { ac.logger.WithError(err).Warning("failed to write close message") return } - err = ac.wsConn.Close() + err = ac.eventConn.Close() if err != nil { ac.logger.WithError(err).Warning("failed to close websocket") } ac.logger.Info("finished shutdown") } -func (ac *APIController) reconnectWS() { +func (ac *APIController) recentEvents() { if ac.wsIsReconnecting { return } @@ -100,46 +101,47 @@ func (ac *APIController) reconnectWS() { Path: strings.ReplaceAll(ac.Client.GetConfig().Servers[0].URL, "api/v3", ""), } attempt := 1 - for { - q := u.Query() - q.Set("attempt", strconv.Itoa(attempt)) - u.RawQuery = q.Encode() - err := ac.initWS(u, ac.Outpost.Pk) - attempt += 1 - if err != nil { - ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier) - time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second) - ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2 - // Limit to 300 seconds (5m) - if ac.wsBackoffMultiplier >= 300 { - ac.wsBackoffMultiplier = 300 + _ = retry.Do( + func() error { + q := u.Query() + q.Set("attempt", strconv.Itoa(attempt)) + u.RawQuery = q.Encode() + err := ac.initEvent(u, ac.Outpost.Pk) + attempt += 1 + if err != nil { + return err } - } else { ac.wsIsReconnecting = false - ac.wsBackoffMultiplier = 1 - return - } - } + return nil + }, + retry.Delay(1*time.Second), + retry.MaxDelay(5*time.Minute), + retry.DelayType(retry.BackOffDelay), + retry.Attempts(0), + retry.OnRetry(func(attempt uint, err error) { + ac.logger.Infof("waiting %d seconds to reconnect", attempt) + }), + ) } -func (ac *APIController) startWSHandler() { - logger := ac.logger.WithField("loop", "ws-handler") +func (ac *APIController) startEventHandler() { + logger := ac.logger.WithField("loop", "event-handler") for { - var wsMsg websocketMessage - if ac.wsConn == nil { - go ac.reconnectWS() + var wsMsg Event + if ac.eventConn == nil { + go ac.recentEvents() time.Sleep(time.Second * 5) continue } - err := ac.wsConn.ReadJSON(&wsMsg) + err := ac.eventConn.ReadJSON(&wsMsg) if err != nil { ConnectionStatus.With(prometheus.Labels{ "outpost_name": ac.Outpost.Name, "outpost_type": ac.Server.Type(), "uuid": ac.instanceUUID.String(), }).Set(0) - logger.WithError(err).Warning("ws read error") - go ac.reconnectWS() + logger.WithError(err).Warning("event read error") + go ac.recentEvents() time.Sleep(time.Second * 5) continue } @@ -149,7 +151,8 @@ func (ac *APIController) startWSHandler() { "uuid": ac.instanceUUID.String(), }).Set(1) switch wsMsg.Instruction { - case WebsocketInstructionTriggerUpdate: + case EventKindAck: + case EventKindTriggerUpdate: time.Sleep(ac.reloadOffset) logger.Debug("Got update trigger...") err := ac.OnRefresh() @@ -164,30 +167,33 @@ func (ac *APIController) startWSHandler() { "build": constants.BUILD(""), }).SetToCurrentTime() } - case WebsocketInstructionProviderSpecific: - for _, h := range ac.wsHandlers { - h(context.Background(), wsMsg.Args) + default: + for _, h := range ac.eventHandlers { + err := h(context.Background(), wsMsg) + if err != nil { + ac.logger.WithError(err).Warning("failed to run event handler") + } } } } } -func (ac *APIController) startWSHealth() { +func (ac *APIController) startEventHealth() { ticker := time.NewTicker(time.Second * 10) for ; true; <-ticker.C { - if ac.wsConn == nil { - go ac.reconnectWS() + if ac.eventConn == nil { + go ac.recentEvents() time.Sleep(time.Second * 5) continue } - err := ac.SendWSHello(map[string]interface{}{}) + err := ac.SendEventHello(map[string]interface{}{}) if err != nil { - ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error") - go ac.reconnectWS() + ac.logger.WithField("loop", "event-health").WithError(err).Warning("event write error") + go ac.recentEvents() time.Sleep(time.Second * 5) continue } else { - ac.logger.WithField("loop", "ws-health").Trace("hello'd") + ac.logger.WithField("loop", "event-health").Trace("hello'd") ConnectionStatus.With(prometheus.Labels{ "outpost_name": ac.Outpost.Name, "outpost_type": ac.Server.Type(), @@ -230,19 +236,19 @@ func (ac *APIController) startIntervalUpdater() { } } -func (a *APIController) AddWSHandler(handler WSHandler) { - a.wsHandlers = append(a.wsHandlers, handler) +func (a *APIController) AddEventHandler(handler EventHandler) { + a.eventHandlers = append(a.eventHandlers, handler) } -func (a *APIController) SendWSHello(args map[string]interface{}) error { - allArgs := a.getWebsocketPingArgs() +func (a *APIController) SendEventHello(args map[string]interface{}) error { + allArgs := a.getEventPingArgs() for key, value := range args { allArgs[key] = value } - aliveMsg := websocketMessage{ - Instruction: WebsocketInstructionHello, + aliveMsg := Event{ + Instruction: EventKindHello, Args: allArgs, } - err := a.wsConn.WriteJSON(aliveMsg) + err := a.eventConn.WriteJSON(aliveMsg) return err } diff --git a/internal/outpost/ak/api_event_msg.go b/internal/outpost/ak/api_event_msg.go new file mode 100644 index 0000000000..6f6d7449fb --- /dev/null +++ b/internal/outpost/ak/api_event_msg.go @@ -0,0 +1,37 @@ +package ak + +import ( + "context" + + "github.com/mitchellh/mapstructure" +) + +type EventKind int + +const ( + // Code used to acknowledge a previous message + EventKindAck EventKind = 0 + // Code used to send a healthcheck keepalive + EventKindHello EventKind = 1 + // Code received to trigger a config update + EventKindTriggerUpdate EventKind = 2 + // Code received to trigger some provider specific function + EventKindProviderSpecific EventKind = 3 + // Code received to identify the end of a session + EventKindSessionEnd EventKind = 4 +) + +type EventHandler func(ctx context.Context, msg Event) error + +type Event struct { + Instruction EventKind `json:"instruction"` + Args interface{} `json:"args"` +} + +func (wm Event) ArgsAs(out interface{}) error { + return mapstructure.Decode(wm.Args, out) +} + +type EventArgsSessionEnd struct { + SessionID string `mapstructure:"session_id"` +} diff --git a/internal/outpost/ak/api_ws_test.go b/internal/outpost/ak/api_event_test.go similarity index 88% rename from internal/outpost/ak/api_ws_test.go rename to internal/outpost/ak/api_event_test.go index f52c353758..6d9b93e2d4 100644 --- a/internal/outpost/ak/api_ws_test.go +++ b/internal/outpost/ak/api_event_test.go @@ -15,7 +15,7 @@ func URLMustParse(u string) *url.URL { return ur } -func TestWebsocketURL(t *testing.T) { +func TestEventWebsocketURL(t *testing.T) { u := URLMustParse("http://localhost:9000?foo=bar") uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" ac := &APIController{} @@ -23,7 +23,7 @@ func TestWebsocketURL(t *testing.T) { assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String()) } -func TestWebsocketURL_Query(t *testing.T) { +func TestEventWebsocketURL_Query(t *testing.T) { u := URLMustParse("http://localhost:9000?foo=bar") uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" ac := &APIController{} @@ -33,7 +33,7 @@ func TestWebsocketURL_Query(t *testing.T) { assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String()) } -func TestWebsocketURL_Subpath(t *testing.T) { +func TestEventWebsocketURL_Subpath(t *testing.T) { u := URLMustParse("http://localhost:9000/foo/bar/") uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" ac := &APIController{} diff --git a/internal/outpost/ak/api_ws_msg.go b/internal/outpost/ak/api_ws_msg.go deleted file mode 100644 index cedecb93d5..0000000000 --- a/internal/outpost/ak/api_ws_msg.go +++ /dev/null @@ -1,19 +0,0 @@ -package ak - -type websocketInstruction int - -const ( - // WebsocketInstructionAck Code used to acknowledge a previous message - WebsocketInstructionAck websocketInstruction = 0 - // WebsocketInstructionHello Code used to send a healthcheck keepalive - WebsocketInstructionHello websocketInstruction = 1 - // WebsocketInstructionTriggerUpdate Code received to trigger a config update - WebsocketInstructionTriggerUpdate websocketInstruction = 2 - // WebsocketInstructionProviderSpecific Code received to trigger some provider specific function - WebsocketInstructionProviderSpecific websocketInstruction = 3 -) - -type websocketMessage struct { - Instruction websocketInstruction `json:"instruction"` - Args map[string]interface{} `json:"args"` -} diff --git a/internal/outpost/ak/test.go b/internal/outpost/ak/test.go index 82a92351ee..d0a8691748 100644 --- a/internal/outpost/ak/test.go +++ b/internal/outpost/ak/test.go @@ -55,11 +55,10 @@ func MockAK(outpost api.Outpost, globalConfig api.Config) *APIController { token: token, logger: log, - reloadOffset: time.Duration(rand.Intn(10)) * time.Second, - instanceUUID: uuid.New(), - Outpost: outpost, - wsBackoffMultiplier: 1, - refreshHandlers: make([]func(), 0), + reloadOffset: time.Duration(rand.Intn(10)) * time.Second, + instanceUUID: uuid.New(), + Outpost: outpost, + refreshHandlers: make([]func(), 0), } ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") return ac diff --git a/internal/outpost/flow/executor.go b/internal/outpost/flow/executor.go index 465715472e..7bfd3cc54b 100644 --- a/internal/outpost/flow/executor.go +++ b/internal/outpost/flow/executor.go @@ -127,7 +127,7 @@ func (fe *FlowExecutor) getAnswer(stage StageComponent) string { return "" } -func (fe *FlowExecutor) GetSession() *http.Cookie { +func (fe *FlowExecutor) SessionCookie() *http.Cookie { return fe.session } diff --git a/internal/outpost/flow/session.go b/internal/outpost/flow/session.go new file mode 100644 index 0000000000..b4e458b8f7 --- /dev/null +++ b/internal/outpost/flow/session.go @@ -0,0 +1,19 @@ +package flow + +import "github.com/golang-jwt/jwt/v5" + +type SessionCookieClaims struct { + jwt.Claims + + SessionID string `json:"sid"` + Authenticated bool `json:"authenticated"` +} + +func (fe *FlowExecutor) Session() *jwt.Token { + sc := fe.SessionCookie() + if sc == nil { + return nil + } + t, _, _ := jwt.NewParser().ParseUnverified(sc.Value, &SessionCookieClaims{}) + return t +} diff --git a/internal/outpost/ldap/bind.go b/internal/outpost/ldap/bind.go index 050dea646f..ae79f69f83 100644 --- a/internal/outpost/ldap/bind.go +++ b/internal/outpost/ldap/bind.go @@ -38,7 +38,14 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD username, err := instance.binder.GetUsername(bindDN) if err == nil { selectedApp = instance.GetAppSlug() - return instance.binder.Bind(username, req) + c, err := instance.binder.Bind(username, req) + if c == ldap.LDAPResultSuccess { + f := instance.GetFlags(req.BindDN) + ls.connectionsSync.Lock() + ls.connections[f.SessionID()] = conn + ls.connectionsSync.Unlock() + } + return c, err } else { req.Log().WithError(err).Debug("Username not for instance") } diff --git a/internal/outpost/ldap/bind/direct/bind.go b/internal/outpost/ldap/bind/direct/bind.go index 36697fbe0f..cc72577716 100644 --- a/internal/outpost/ldap/bind/direct/bind.go +++ b/internal/outpost/ldap/bind/direct/bind.go @@ -27,8 +27,9 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul passed, err := fe.Execute() flags := flags.UserFlags{ - Session: fe.GetSession(), - UserPk: flags.InvalidUserPK, + Session: fe.SessionCookie(), + SessionJWT: fe.Session(), + UserPk: flags.InvalidUserPK, } // only set flags if we don't have flags for this DN yet // as flags are only checked during the bind, we can remember whether a certain DN diff --git a/internal/outpost/ldap/close.go b/internal/outpost/ldap/close.go new file mode 100644 index 0000000000..d050e07639 --- /dev/null +++ b/internal/outpost/ldap/close.go @@ -0,0 +1,20 @@ +package ldap + +import "net" + +func (ls *LDAPServer) Close(dn string, conn net.Conn) error { + ls.connectionsSync.Lock() + defer ls.connectionsSync.Unlock() + key := "" + for k, c := range ls.connections { + if c == conn { + key = k + break + } + } + if key == "" { + return nil + } + delete(ls.connections, key) + return nil +} diff --git a/internal/outpost/ldap/flags/flags.go b/internal/outpost/ldap/flags/flags.go index 60538de2c0..bf7faf473e 100644 --- a/internal/outpost/ldap/flags/flags.go +++ b/internal/outpost/ldap/flags/flags.go @@ -1,16 +1,30 @@ package flags import ( + "crypto/sha256" + "encoding/hex" "net/http" + "github.com/golang-jwt/jwt/v5" "goauthentik.io/api/v3" + "goauthentik.io/internal/outpost/flow" ) const InvalidUserPK = -1 type UserFlags struct { - UserInfo *api.User - UserPk int32 - CanSearch bool - Session *http.Cookie + UserInfo *api.User + UserPk int32 + CanSearch bool + Session *http.Cookie + SessionJWT *jwt.Token +} + +func (uf UserFlags) SessionID() string { + if uf.SessionJWT == nil { + return "" + } + h := sha256.New() + h.Write([]byte(uf.SessionJWT.Claims.(*flow.SessionCookieClaims).SessionID)) + return hex.EncodeToString(h.Sum(nil)) } diff --git a/internal/outpost/ldap/ldap.go b/internal/outpost/ldap/ldap.go index 57d91e4ed4..5bbdc0d167 100644 --- a/internal/outpost/ldap/ldap.go +++ b/internal/outpost/ldap/ldap.go @@ -18,21 +18,26 @@ import ( ) type LDAPServer struct { - s *ldap.Server - log *log.Entry - ac *ak.APIController - cs *ak.CryptoStore - defaultCert *tls.Certificate - providers []*ProviderInstance + s *ldap.Server + log *log.Entry + ac *ak.APIController + cs *ak.CryptoStore + defaultCert *tls.Certificate + providers []*ProviderInstance + connections map[string]net.Conn + connectionsSync sync.Mutex } func NewServer(ac *ak.APIController) ak.Outpost { ls := &LDAPServer{ - log: log.WithField("logger", "authentik.outpost.ldap"), - ac: ac, - cs: ak.NewCryptoStore(ac.Client.CryptoApi), - providers: []*ProviderInstance{}, + log: log.WithField("logger", "authentik.outpost.ldap"), + ac: ac, + cs: ak.NewCryptoStore(ac.Client.CryptoApi), + providers: []*ProviderInstance{}, + connections: map[string]net.Conn{}, + connectionsSync: sync.Mutex{}, } + ac.AddEventHandler(ls.handleWSSessionEnd) s := ldap.NewServer() s.EnforceLDAP = true @@ -50,6 +55,7 @@ func NewServer(ac *ak.APIController) ak.Outpost { s.BindFunc("", ls) s.UnbindFunc("", ls) s.SearchFunc("", ls) + s.CloseFunc("", ls) return ls } @@ -117,3 +123,23 @@ func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) { p.binder.TimerFlowCacheExpiry(ctx) } } + +func (ls *LDAPServer) handleWSSessionEnd(ctx context.Context, msg ak.Event) error { + if msg.Instruction != ak.EventKindSessionEnd { + return nil + } + mmsg := ak.EventArgsSessionEnd{} + err := msg.ArgsAs(&mmsg) + if err != nil { + return err + } + ls.connectionsSync.Lock() + defer ls.connectionsSync.Unlock() + ls.log.Info("Disconnecting session due to session end event") + conn, ok := ls.connections[mmsg.SessionID] + if !ok { + return nil + } + delete(ls.connections, mmsg.SessionID) + return conn.Close() +} diff --git a/internal/outpost/ldap/search/direct/schema.go b/internal/outpost/ldap/search/direct/schema.go index 6fe88ead72..94fb225be2 100644 --- a/internal/outpost/ldap/search/direct/schema.go +++ b/internal/outpost/ldap/search/direct/schema.go @@ -44,38 +44,40 @@ func (ds *DirectSearcher) SearchSubschema(req *search.Request) (ldap.ServerSearc { Name: "attributeTypes", Values: []string{ - "( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )", - "( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", - "( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", - "( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", - "( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )", - "( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", - "( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", - "( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", "( 0.9.2342.19200300.100.1.1 NAME 'uid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 0.9.2342.19200300.100.1.3 NAME 'mail' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 0.9.2342.19200300.100.1.41 NAME 'mobile' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", - "( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.2.102 NAME 'memberOf' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' NO-USER-MODIFICATION )", + "( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.2.131 NAME 'co' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.2.141 NAME 'department' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.4.1 NAME 'name' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )", - "( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.4.221 NAME 'sAMAccountName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.4.261 NAME 'division' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.4.750 NAME 'groupType' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", "( 1.2.840.113556.1.4.782 NAME 'objectCategory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' SINGLE-VALUE )", "( 1.3.6.1.1.1.1.0 NAME 'uidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", "( 1.3.6.1.1.1.1.1 NAME 'gidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", "( 1.3.6.1.1.1.1.12 NAME 'memberUid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.26' )", + "( 2.5.18.1 NAME 'createTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )", + "( 2.5.18.2 NAME 'modifyTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )", + "( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", + "( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", + "( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", + "( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )", + "( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", + "( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", + "( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", + "( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )", + "( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", + "( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", // Custom attributes // Temporarily use 1.3.6.1.4.1.26027.1.1 as a base diff --git a/internal/outpost/proxyv2/proxyv2.go b/internal/outpost/proxyv2/proxyv2.go index 690c6135f9..65c74469b9 100644 --- a/internal/outpost/proxyv2/proxyv2.go +++ b/internal/outpost/proxyv2/proxyv2.go @@ -66,7 +66,7 @@ func NewProxyServer(ac *ak.APIController) ak.Outpost { globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) rootMux.PathPrefix("/").HandlerFunc(s.Handle) - ac.AddWSHandler(s.handleWSMessage) + ac.AddEventHandler(s.handleWSMessage) return s } diff --git a/internal/outpost/proxyv2/ws.go b/internal/outpost/proxyv2/ws.go index 4632a67cbe..c08c354ed9 100644 --- a/internal/outpost/proxyv2/ws.go +++ b/internal/outpost/proxyv2/ws.go @@ -3,48 +3,27 @@ package proxyv2 import ( "context" - "github.com/mitchellh/mapstructure" + "goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/proxyv2/application" ) -type WSProviderSubType string - -const ( - WSProviderSubTypeLogout WSProviderSubType = "logout" -) - -type WSProviderMsg struct { - SubType WSProviderSubType `mapstructure:"sub_type"` - SessionID string `mapstructure:"session_id"` -} - -func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) { - msg := &WSProviderMsg{} - err := mapstructure.Decode(args, &msg) - if err != nil { - return nil, err +func (ps *ProxyServer) handleWSMessage(ctx context.Context, msg ak.Event) error { + if msg.Instruction != ak.EventKindSessionEnd { + return nil } - return msg, nil -} - -func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) { - msg, err := ParseWSProvider(args) + mmsg := ak.EventArgsSessionEnd{} + err := msg.ArgsAs(&mmsg) if err != nil { - ps.log.WithError(err).Warning("invalid provider-specific ws message") - return + return err } - switch msg.SubType { - case WSProviderSubTypeLogout: - for _, p := range ps.apps { - ps.log.WithField("provider", p.Host).Debug("Logging out") - err := p.Logout(ctx, func(c application.Claims) bool { - return c.Sid == msg.SessionID - }) - if err != nil { - ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") - } + for _, p := range ps.apps { + ps.log.WithField("provider", p.Host).Debug("Logging out") + err := p.Logout(ctx, func(c application.Claims) bool { + return c.Sid == mmsg.SessionID + }) + if err != nil { + ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") } - default: - ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type") } + return nil } diff --git a/internal/outpost/rac/rac.go b/internal/outpost/rac/rac.go index 028ded0959..ab50d3a7fb 100644 --- a/internal/outpost/rac/rac.go +++ b/internal/outpost/rac/rac.go @@ -6,7 +6,6 @@ import ( "strconv" "sync" - "github.com/mitchellh/mapstructure" log "github.com/sirupsen/logrus" "github.com/wwt/guac" @@ -30,7 +29,7 @@ func NewServer(ac *ak.APIController) ak.Outpost { connm: sync.RWMutex{}, conns: map[string]connection.Connection{}, } - ac.AddWSHandler(rs.wsHandler) + ac.AddEventHandler(rs.wsHandler) return rs } @@ -52,12 +51,14 @@ func parseIntOrZero(input string) int { return x } -func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) { +func (rs *RACServer) wsHandler(ctx context.Context, msg ak.Event) error { + if msg.Instruction != ak.EventKindProviderSpecific { + return nil + } wsm := WSMessage{} - err := mapstructure.Decode(args, &wsm) + err := msg.ArgsAs(&wsm) if err != nil { - rs.log.WithError(err).Warning("invalid ws message") - return + return err } config := guac.NewGuacamoleConfiguration() config.Protocol = wsm.Protocol @@ -71,23 +72,23 @@ func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) } cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config) if err != nil { - rs.log.WithError(err).Warning("failed to setup connection") - return + return err } cc.OnError = func(err error) { rs.connm.Lock() delete(rs.conns, wsm.ConnID) - _ = rs.ac.SendWSHello(map[string]interface{}{ + _ = rs.ac.SendEventHello(map[string]interface{}{ "active_connections": len(rs.conns), }) rs.connm.Unlock() } rs.connm.Lock() rs.conns[wsm.ConnID] = *cc - _ = rs.ac.SendWSHello(map[string]interface{}{ + _ = rs.ac.SendEventHello(map[string]interface{}{ "active_connections": len(rs.conns), }) rs.connm.Unlock() + return nil } func (rs *RACServer) Start() error {