mirror of
https://github.com/goauthentik/authentik
synced 2026-04-26 17:45:21 +02:00
Compare commits
181 Commits
web/flows/
...
admin/vers
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fbce9611d2 | ||
|
|
e6643a69cd | ||
|
|
0fdeaee559 | ||
|
|
f9fd1bbf09 | ||
|
|
3ba3b11a76 | ||
|
|
19e558e916 | ||
|
|
e15fadfedd | ||
|
|
52854e61c7 | ||
|
|
53aa0113ca | ||
|
|
9f71face62 | ||
|
|
2fadefb5b4 | ||
|
|
23e92bceae | ||
|
|
1ff2eea20a | ||
|
|
abcd2179bf | ||
|
|
6a4b5850a0 | ||
|
|
821c8c36cd | ||
|
|
8838efe3c0 | ||
|
|
433a4a3037 | ||
|
|
2d69a67e9d | ||
|
|
1294cc64e8 | ||
|
|
910326a05a | ||
|
|
9257b3e570 | ||
|
|
cdd18a7e5a | ||
|
|
88bea46648 | ||
|
|
295090a80b | ||
|
|
bff607a5c3 | ||
|
|
bfb2fb4fcf | ||
|
|
93015b0fce | ||
|
|
9b6c0d3f1a | ||
|
|
66e95ddb20 | ||
|
|
c5d8524a7d | ||
|
|
a4761064c2 | ||
|
|
b0de8bf71f | ||
|
|
32100fd3b9 | ||
|
|
4815e97162 | ||
|
|
dee99c38bb | ||
|
|
a024056b62 | ||
|
|
a8dc21b707 | ||
|
|
7ccda743df | ||
|
|
0c795dd077 | ||
|
|
5df9ed3582 | ||
|
|
a47b4934a5 | ||
|
|
338a6e74f4 | ||
|
|
8897af1048 | ||
|
|
56ec3f7def | ||
|
|
53fd893d91 | ||
|
|
f7d9a8cafe | ||
|
|
f97c1071f3 | ||
|
|
4da1115a7c | ||
|
|
63b1ccd4c3 | ||
|
|
63aa7f4684 | ||
|
|
d997930b60 | ||
|
|
a088a62981 | ||
|
|
118e05f256 | ||
|
|
b30500094f | ||
|
|
21af51ba59 | ||
|
|
87da0497e0 | ||
|
|
87317d6e7f | ||
|
|
071305da18 | ||
|
|
1dc8ed5e55 | ||
|
|
dc8dee985f | ||
|
|
2b20b06baa | ||
|
|
6cab1f85e4 | ||
|
|
f836c38b18 | ||
|
|
07e373e505 | ||
|
|
e361d38978 | ||
|
|
3ba1691db6 | ||
|
|
7c2987ea32 | ||
|
|
4ca88caf07 | ||
|
|
6c939341b0 | ||
|
|
4142584788 | ||
|
|
f6fbafd280 | ||
|
|
7c9555bee8 | ||
|
|
82cd64dfe7 | ||
|
|
28f0b48e33 | ||
|
|
38c02dc490 | ||
|
|
79505969db | ||
|
|
9870888456 | ||
|
|
5c06e1920e | ||
|
|
1506ad8aa4 | ||
|
|
21b6204c90 | ||
|
|
05621735cb | ||
|
|
f9ffd35ab8 | ||
|
|
c3ded3a835 | ||
|
|
7629c22050 | ||
|
|
29a66410fd | ||
|
|
f147d40c5f | ||
|
|
15b556c1be | ||
|
|
522e8a26a2 | ||
|
|
403d762f65 | ||
|
|
cbc65ffd74 | ||
|
|
9a9bafdfb4 | ||
|
|
198d2a1a8a | ||
|
|
239edace16 | ||
|
|
370d5ff0c0 | ||
|
|
635b09621b | ||
|
|
4335498ac5 | ||
|
|
72af009de8 | ||
|
|
3a07d5d829 | ||
|
|
7122891f0f | ||
|
|
c32d6cc75e | ||
|
|
eaf6be74f3 | ||
|
|
c35650afbd | ||
|
|
a1f9ff8b7d | ||
|
|
962f7513ba | ||
|
|
0ec5ea69ef | ||
|
|
d8a3098329 | ||
|
|
80ad97b28d | ||
|
|
cd98767dbc | ||
|
|
30f09e8c45 | ||
|
|
154bcb58a6 | ||
|
|
597945edf1 | ||
|
|
38d6e39fe0 | ||
|
|
1a6065f72a | ||
|
|
d07e0f015d | ||
|
|
7f931917fa | ||
|
|
d7fb684292 | ||
|
|
bd0fa7be98 | ||
|
|
2907808a7e | ||
|
|
c53016b2e5 | ||
|
|
4479587baa | ||
|
|
08d24a1871 | ||
|
|
42ea8bb3ed | ||
|
|
c9a07fa18d | ||
|
|
4130446cbc | ||
|
|
b4aecbd782 | ||
|
|
981d2af109 | ||
|
|
db96e13813 | ||
|
|
3d39cc4974 | ||
|
|
d36ec31224 | ||
|
|
bb7a2002f2 | ||
|
|
8fff802936 | ||
|
|
0f3fb9f93c | ||
|
|
1e76d1f883 | ||
|
|
140d9fe95c | ||
|
|
67eacbe860 | ||
|
|
435b815617 | ||
|
|
0459feeb8a | ||
|
|
4e6e730014 | ||
|
|
1231fc8237 | ||
|
|
b7f320d7cc | ||
|
|
35073b03ac | ||
|
|
b3b8b8bb1c | ||
|
|
17ee42f98f | ||
|
|
e8f95a4b08 | ||
|
|
decc0c1ae1 | ||
|
|
716bfa9043 | ||
|
|
4d8feb15e3 | ||
|
|
d50eab08e8 | ||
|
|
09b2a2bd4f | ||
|
|
210d9cf31c | ||
|
|
a0291a1b32 | ||
|
|
790ae0c3d8 | ||
|
|
8fc744fb56 | ||
|
|
392011cac4 | ||
|
|
15316b6bae | ||
|
|
dccb1d01f0 | ||
|
|
e8cd762c6e | ||
|
|
12847d9a87 | ||
|
|
6c4cb06825 | ||
|
|
aa8e971477 | ||
|
|
6c02d5a316 | ||
|
|
2f3259bf13 | ||
|
|
8b7a538419 | ||
|
|
d0127d83c9 | ||
|
|
262ca4aea9 | ||
|
|
9923cb73a6 | ||
|
|
b58a8774d4 | ||
|
|
bf6a37a5dc | ||
|
|
1eda16cbd6 | ||
|
|
8c3397e5f2 | ||
|
|
974b4d5c82 | ||
|
|
00daba0d0c | ||
|
|
63d547194c | ||
|
|
ec171bd282 | ||
|
|
155fa433b3 | ||
|
|
7a88fd5b6b | ||
|
|
7d9fb85827 | ||
|
|
0021e5fa25 | ||
|
|
6919838c12 | ||
|
|
9841d976e1 |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -31,4 +31,4 @@ If changes to the frontend have been made
|
|||||||
If applicable
|
If applicable
|
||||||
|
|
||||||
- [ ] The documentation has been updated
|
- [ ] The documentation has been updated
|
||||||
- [ ] The documentation has been formatted (`make website`)
|
- [ ] The documentation has been formatted (`make docs`)
|
||||||
|
|||||||
4
.github/workflows/api-ts-publish.yml
vendored
4
.github/workflows/api-ts-publish.yml
vendored
@@ -27,8 +27,8 @@ jobs:
|
|||||||
- name: Publish package
|
- name: Publish package
|
||||||
working-directory: gen-ts-api/
|
working-directory: gen-ts-api/
|
||||||
run: |
|
run: |
|
||||||
npm ci
|
npm i
|
||||||
npm publish
|
npm publish --tag generated
|
||||||
env:
|
env:
|
||||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_PUBLISH_TOKEN }}
|
NODE_AUTH_TOKEN: ${{ secrets.NPM_PUBLISH_TOKEN }}
|
||||||
- name: Upgrade /web
|
- name: Upgrade /web
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
|||||||
go build -o /go/authentik ./cmd/server
|
go build -o /go/authentik ./cmd/server
|
||||||
|
|
||||||
# Stage 3: MaxMind GeoIP
|
# Stage 3: MaxMind GeoIP
|
||||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.0 AS geoip
|
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.1 AS geoip
|
||||||
|
|
||||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||||
ENV GEOIPUPDATE_VERBOSE="1"
|
ENV GEOIPUPDATE_VERBOSE="1"
|
||||||
@@ -75,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
|||||||
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||||
|
|
||||||
# Stage 4: Download uv
|
# Stage 4: Download uv
|
||||||
FROM ghcr.io/astral-sh/uv:0.7.19 AS uv
|
FROM ghcr.io/astral-sh/uv:0.7.21 AS uv
|
||||||
# Stage 5: Base python image
|
# Stage 5: Base python image
|
||||||
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
|
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
|
||||||
|
|
||||||
|
|||||||
43
Makefile
43
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: gen dev-reset all clean test web website
|
.PHONY: gen dev-reset all clean test web docs
|
||||||
|
|
||||||
SHELL := /usr/bin/env bash
|
SHELL := /usr/bin/env bash
|
||||||
.SHELLFLAGS += ${SHELLFLAGS} -e -o pipefail
|
.SHELLFLAGS += ${SHELLFLAGS} -e -o pipefail
|
||||||
@@ -73,7 +73,7 @@ core-i18n-extract:
|
|||||||
--ignore website \
|
--ignore website \
|
||||||
-l en
|
-l en
|
||||||
|
|
||||||
install: web-install website-install core-install ## Install all requires dependencies for `web`, `website` and `core`
|
install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core`
|
||||||
|
|
||||||
dev-drop-db:
|
dev-drop-db:
|
||||||
dropdb -U ${pg_user} -h ${pg_host} ${pg_name}
|
dropdb -U ${pg_user} -h ${pg_host} ${pg_name}
|
||||||
@@ -183,18 +183,23 @@ gen-dev-config: ## Generate a local development config file
|
|||||||
|
|
||||||
gen: gen-build gen-client-ts
|
gen: gen-build gen-client-ts
|
||||||
|
|
||||||
|
#########################
|
||||||
|
## Node.js
|
||||||
|
#########################
|
||||||
|
|
||||||
|
node-install: ## Install the necessary libraries to build Node.js packages
|
||||||
|
npm ci
|
||||||
|
npm ci --prefix web
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
## Web
|
## Web
|
||||||
#########################
|
#########################
|
||||||
|
|
||||||
web-build: web-install ## Build the Authentik UI
|
web-build: node-install ## Build the Authentik UI
|
||||||
cd web && npm run build
|
cd web && npm run build
|
||||||
|
|
||||||
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
|
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
|
||||||
|
|
||||||
web-install: ## Install the necessary libraries to build the Authentik UI
|
|
||||||
cd web && npm ci
|
|
||||||
|
|
||||||
web-test: ## Run tests for the Authentik UI
|
web-test: ## Run tests for the Authentik UI
|
||||||
cd web && npm run test
|
cd web && npm run test
|
||||||
|
|
||||||
@@ -221,22 +226,28 @@ web-i18n-extract:
|
|||||||
cd web && npm run extract-locales
|
cd web && npm run extract-locales
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
## Website
|
## Docs
|
||||||
#########################
|
#########################
|
||||||
|
|
||||||
website: website-lint-fix website-build ## Automatically fix formatting issues in the Authentik website/docs source code, lint the code, and compile it
|
docs: docs-lint-fix docs-build ## Automatically fix formatting issues in the Authentik docs source code, lint the code, and compile it
|
||||||
|
|
||||||
website-install:
|
docs-install:
|
||||||
cd website && npm ci
|
npm ci --prefix website
|
||||||
|
|
||||||
website-lint-fix: lint-codespell
|
docs-lint-fix: lint-codespell
|
||||||
cd website && npm run prettier
|
npm run prettier --prefix website
|
||||||
|
|
||||||
website-build:
|
docs-build:
|
||||||
cd website && npm run build
|
npm run build --prefix website
|
||||||
|
|
||||||
website-watch: ## Build and watch the documentation website, updating automatically
|
docs-watch: ## Build and watch the topics documentation
|
||||||
cd website && npm run watch
|
npm run start --prefix website
|
||||||
|
|
||||||
|
docs-integrations-build:
|
||||||
|
npm run build --prefix website -w integrations
|
||||||
|
|
||||||
|
docs-integrations-watch: ## Build and watch the Integrations documentation
|
||||||
|
npm run start --prefix website -w integrations
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
## Docker
|
## Docker
|
||||||
|
|||||||
@@ -42,7 +42,11 @@ class Exporter:
|
|||||||
if model in self.excluded_models:
|
if model in self.excluded_models:
|
||||||
continue
|
continue
|
||||||
for obj in self.get_model_instances(model):
|
for obj in self.get_model_instances(model):
|
||||||
yield BlueprintEntry.from_model(obj)
|
yield BlueprintEntry.from_model(self.alter_model(obj))
|
||||||
|
|
||||||
|
def alter_model(self, model: Model):
|
||||||
|
"""Hook to modify the model before exporting"""
|
||||||
|
return model
|
||||||
|
|
||||||
def get_model_instances(self, model: type[Model]) -> QuerySet:
|
def get_model_instances(self, model: type[Model]) -> QuerySet:
|
||||||
"""Return a queryset for `model`. Can be used to filter some
|
"""Return a queryset for `model`. Can be used to filter some
|
||||||
|
|||||||
@@ -52,6 +52,27 @@ class TestBrands(APITestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_brand_subdomain_same_suffix(self):
|
||||||
|
"""Test Current brand API"""
|
||||||
|
Brand.objects.all().delete()
|
||||||
|
Brand.objects.create(domain="bar.baz", branding_title="custom")
|
||||||
|
Brand.objects.create(domain="foo.bar.baz", branding_title="custom")
|
||||||
|
self.assertJSONEqual(
|
||||||
|
self.client.get(
|
||||||
|
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
|
||||||
|
).content.decode(),
|
||||||
|
{
|
||||||
|
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||||
|
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||||
|
"branding_title": "custom",
|
||||||
|
"branding_custom_css": "",
|
||||||
|
"matched_domain": "foo.bar.baz",
|
||||||
|
"ui_footer_links": [],
|
||||||
|
"ui_theme": Themes.AUTOMATIC,
|
||||||
|
"default_locale": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_fallback(self):
|
def test_fallback(self):
|
||||||
"""Test fallback brand"""
|
"""Test fallback brand"""
|
||||||
Brand.objects.all().delete()
|
Brand.objects.all().delete()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
from django.db.models import F, Q
|
from django.db.models import F, Q
|
||||||
from django.db.models import Value as V
|
from django.db.models import Value as V
|
||||||
|
from django.db.models.functions import Length
|
||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
from django.utils.html import _json_script_escapes
|
from django.utils.html import _json_script_escapes
|
||||||
from django.utils.safestring import mark_safe
|
from django.utils.safestring import mark_safe
|
||||||
@@ -20,9 +21,9 @@ DEFAULT_BRAND = Brand(domain="fallback")
|
|||||||
def get_brand_for_request(request: HttpRequest) -> Brand:
|
def get_brand_for_request(request: HttpRequest) -> Brand:
|
||||||
"""Get brand object for current request"""
|
"""Get brand object for current request"""
|
||||||
db_brands = (
|
db_brands = (
|
||||||
Brand.objects.annotate(host_domain=V(request.get_host()))
|
Brand.objects.annotate(host_domain=V(request.get_host()), match_length=Length("domain"))
|
||||||
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
|
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
|
||||||
.order_by("default")
|
.order_by("-match_length", "default")
|
||||||
)
|
)
|
||||||
brands = list(db_brands.all())
|
brands = list(db_brands.all())
|
||||||
if len(brands) < 1:
|
if len(brands) < 1:
|
||||||
|
|||||||
@@ -149,10 +149,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
|||||||
return applications
|
return applications
|
||||||
|
|
||||||
def _filter_applications_with_launch_url(
|
def _filter_applications_with_launch_url(
|
||||||
self, pagined_apps: Iterator[Application]
|
self, paginated_apps: Iterator[Application]
|
||||||
) -> list[Application]:
|
) -> list[Application]:
|
||||||
applications = []
|
applications = []
|
||||||
for app in pagined_apps:
|
for app in paginated_apps:
|
||||||
if app.get_launch_url():
|
if app.get_launch_url():
|
||||||
applications.append(app)
|
applications.append(app)
|
||||||
return applications
|
return applications
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from authentik.core.expression.exceptions import SkipObjectException
|
|||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.policies.types import PolicyRequest
|
from authentik.policies.types import PolicyRequest
|
||||||
|
|
||||||
PROPERTY_MAPPING_TIME = Histogram(
|
PROPERTY_MAPPING_TIME = Histogram(
|
||||||
@@ -69,12 +68,11 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
|||||||
# For dry-run requests we don't save exceptions
|
# For dry-run requests we don't save exceptions
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
return
|
return
|
||||||
error_string = exception_to_string(exc)
|
|
||||||
event = Event.new(
|
event = Event.new(
|
||||||
EventAction.PROPERTY_MAPPING_EXCEPTION,
|
EventAction.PROPERTY_MAPPING_EXCEPTION,
|
||||||
expression=expression_source,
|
expression=expression_source,
|
||||||
message=error_string,
|
message="Failed to execute property mapping",
|
||||||
)
|
).with_exception(exc)
|
||||||
if "request" in self._context:
|
if "request" in self._context:
|
||||||
req: PolicyRequest = self._context["request"]
|
req: PolicyRequest = self._context["request"]
|
||||||
if req.http_request:
|
if req.http_request:
|
||||||
|
|||||||
24
authentik/core/migrations/0049_alter_token_options.py
Normal file
24
authentik/core/migrations/0049_alter_token_options.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Generated by Django 5.1.11 on 2025-07-03 13:08
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_core", "0048_delete_oldauthenticatedsession_content_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterModelOptions(
|
||||||
|
name="token",
|
||||||
|
options={
|
||||||
|
"permissions": [
|
||||||
|
("view_token_key", "View token's key"),
|
||||||
|
("set_token_key", "Set a token's key"),
|
||||||
|
],
|
||||||
|
"verbose_name": "Token",
|
||||||
|
"verbose_name_plural": "Tokens",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -953,7 +953,10 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
|
|||||||
models.Index(fields=["identifier"]),
|
models.Index(fields=["identifier"]),
|
||||||
models.Index(fields=["key"]),
|
models.Index(fields=["key"]),
|
||||||
]
|
]
|
||||||
permissions = [("view_token_key", _("View token's key"))]
|
permissions = [
|
||||||
|
("view_token_key", _("View token's key")),
|
||||||
|
("set_token_key", _("Set a token's key")),
|
||||||
|
]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
description = f"{self.identifier}"
|
description = f"{self.identifier}"
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from django.http import HttpResponse
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from drf_spectacular.types import OpenApiTypes
|
from drf_spectacular.types import OpenApiTypes
|
||||||
@@ -10,16 +11,21 @@ from rest_framework.decorators import action
|
|||||||
from rest_framework.exceptions import ValidationError
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.fields import CharField, IntegerField
|
from rest_framework.fields import CharField, IntegerField
|
||||||
from rest_framework.permissions import IsAuthenticated
|
from rest_framework.permissions import IsAuthenticated
|
||||||
|
from rest_framework.renderers import BaseRenderer
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.validators import UniqueValidator
|
||||||
|
from rest_framework.views import APIView
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||||
from authentik.core.models import User, UserTypes
|
from authentik.core.models import User, UserTypes
|
||||||
|
from authentik.enterprise.bundle import generate_support_bundle
|
||||||
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
|
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
|
||||||
from authentik.enterprise.models import License
|
from authentik.enterprise.models import License
|
||||||
from authentik.rbac.decorators import permission_required
|
from authentik.rbac.decorators import permission_required
|
||||||
|
from authentik.rbac.permissions import HasPermission
|
||||||
from authentik.tenants.utils import get_unique_identifier
|
from authentik.tenants.utils import get_unique_identifier
|
||||||
|
|
||||||
|
|
||||||
@@ -53,6 +59,7 @@ class LicenseSerializer(ModelSerializer):
|
|||||||
"external_users",
|
"external_users",
|
||||||
]
|
]
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
|
"key": {"validators": [UniqueValidator(queryset=License.objects.all())]},
|
||||||
"name": {"read_only": True},
|
"name": {"read_only": True},
|
||||||
"expiry": {"read_only": True},
|
"expiry": {"read_only": True},
|
||||||
"internal_users": {"read_only": True},
|
"internal_users": {"read_only": True},
|
||||||
@@ -145,3 +152,24 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
|||||||
)
|
)
|
||||||
response.is_valid(raise_exception=True)
|
response.is_valid(raise_exception=True)
|
||||||
return Response(response.data)
|
return Response(response.data)
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryRenderer(BaseRenderer):
|
||||||
|
media_type = "application/gzip"
|
||||||
|
format = "bin"
|
||||||
|
|
||||||
|
|
||||||
|
class SupportBundleView(APIView):
|
||||||
|
"""Generate a support bundle."""
|
||||||
|
|
||||||
|
permission_classes = [HasPermission("authentik_rbac.view_system_info")]
|
||||||
|
pagination_class = None
|
||||||
|
filter_backends = []
|
||||||
|
renderer_classes = [BinaryRenderer]
|
||||||
|
|
||||||
|
@extend_schema(responses=bytes, request=None)
|
||||||
|
def post(self, request: Request) -> Response:
|
||||||
|
"""Generate a support bundle."""
|
||||||
|
response = HttpResponse(generate_support_bundle(), content_type=BinaryRenderer.media_type)
|
||||||
|
response["Content-Disposition"] = 'attachment; filename="authentik_support.tgz"'
|
||||||
|
return response
|
||||||
|
|||||||
@@ -65,13 +65,17 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
|||||||
data[field.name] = deepcopy(field_value)
|
data[field.name] = deepcopy(field_value)
|
||||||
return cleanse_dict(data)
|
return cleanse_dict(data)
|
||||||
|
|
||||||
def diff(self, before: dict, after: dict) -> dict:
|
def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
|
||||||
"""Generate diff between dicts"""
|
"""Generate diff between dicts"""
|
||||||
diff = {}
|
diff = {}
|
||||||
for key, value in before.items():
|
for key, value in before.items():
|
||||||
|
if update_fields and key not in update_fields:
|
||||||
|
continue
|
||||||
if after.get(key) != value:
|
if after.get(key) != value:
|
||||||
diff[key] = {"previous_value": value, "new_value": after.get(key)}
|
diff[key] = {"previous_value": value, "new_value": after.get(key)}
|
||||||
for key, value in after.items():
|
for key, value in after.items():
|
||||||
|
if update_fields and key not in update_fields:
|
||||||
|
continue
|
||||||
if key not in before and key not in diff and before.get(key) != value:
|
if key not in before and key not in diff and before.get(key) != value:
|
||||||
diff[key] = {"previous_value": before.get(key), "new_value": value}
|
diff[key] = {"previous_value": before.get(key), "new_value": value}
|
||||||
return sanitize_item(diff)
|
return sanitize_item(diff)
|
||||||
@@ -95,6 +99,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
|||||||
instance: Model,
|
instance: Model,
|
||||||
created: bool,
|
created: bool,
|
||||||
thread_kwargs: dict | None = None,
|
thread_kwargs: dict | None = None,
|
||||||
|
update_fields: list[str] | None = None,
|
||||||
**_,
|
**_,
|
||||||
):
|
):
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
@@ -108,7 +113,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
|||||||
prev_state = {}
|
prev_state = {}
|
||||||
# Get current state
|
# Get current state
|
||||||
new_state = self.serialize_simple(instance)
|
new_state = self.serialize_simple(instance)
|
||||||
diff = self.diff(prev_state, new_state)
|
diff = self.diff(prev_state, new_state, update_fields)
|
||||||
thread_kwargs["diff"] = diff
|
thread_kwargs["diff"] = diff
|
||||||
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
|
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from rest_framework.test import APITestCase
|
|||||||
|
|
||||||
from authentik.core.models import Group, User
|
from authentik.core.models import Group, User
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
|
from authentik.enterprise.audit.middleware import EnterpriseAuditMiddleware
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.events.utils import sanitize_item
|
from authentik.events.utils import sanitize_item
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
@@ -208,3 +209,23 @@ class TestEnterpriseAudit(APITestCase):
|
|||||||
diff,
|
diff,
|
||||||
{"users": {"remove": [user.pk]}},
|
{"users": {"remove": [user.pk]}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||||
|
PropertyMock(return_value=True),
|
||||||
|
)
|
||||||
|
def test_diff_update_fields(self):
|
||||||
|
"""Test update audit log"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
diff = EnterpriseAuditMiddleware(None).diff(
|
||||||
|
{
|
||||||
|
"foo": "bar",
|
||||||
|
"is_active": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"foo": "baz",
|
||||||
|
"is_active": True,
|
||||||
|
},
|
||||||
|
update_fields=["is_active"],
|
||||||
|
)
|
||||||
|
self.assertEqual(diff, {"is_active": {"new_value": True, "previous_value": False}})
|
||||||
|
|||||||
53
authentik/enterprise/bundle.py
Normal file
53
authentik/enterprise/bundle.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import re
|
||||||
|
from io import BytesIO
|
||||||
|
from tarfile import TarInfo, open
|
||||||
|
|
||||||
|
from django.db.models import Model
|
||||||
|
from django.db.models.fields import CharField, SlugField, TextField
|
||||||
|
from django.db.models.fields.json import JSONField
|
||||||
|
|
||||||
|
from authentik.blueprints.v1.exporter import Exporter
|
||||||
|
from authentik.core.models import User
|
||||||
|
from lifecycle.support import encrypt, generate
|
||||||
|
|
||||||
|
SENSITIVE_VALUE_PLACEHOLDER = "<REDACTED>"
|
||||||
|
|
||||||
|
|
||||||
|
class SupportExporter(Exporter):
|
||||||
|
"""Blueprint exporter which censors sensitive model attributes"""
|
||||||
|
|
||||||
|
sensitive_fields = re.compile(
|
||||||
|
# Partially taken from Django's SafeExceptionReporterFilter
|
||||||
|
"API|AUTH|TOKEN|KEY|SECRET|PASS|SIGNATURE|CREDENTIALS",
|
||||||
|
re.I,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.excluded_models.append(User)
|
||||||
|
|
||||||
|
def alter_model(self, model: Model):
|
||||||
|
for field in model._meta.fields:
|
||||||
|
if not self.sensitive_fields.search(field.name):
|
||||||
|
continue
|
||||||
|
if isinstance(field, TextField | CharField | SlugField):
|
||||||
|
setattr(model, field.name, SENSITIVE_VALUE_PLACEHOLDER)
|
||||||
|
elif isinstance(field, JSONField):
|
||||||
|
setattr(model, field.name, {})
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def generate_support_bundle():
|
||||||
|
fh = BytesIO()
|
||||||
|
exporter = SupportExporter()
|
||||||
|
files = {
|
||||||
|
"authentik/support.jwe": encrypt(generate()),
|
||||||
|
"authentik/blueprint.yaml": exporter.export_to_string(),
|
||||||
|
}
|
||||||
|
with open(fileobj=fh, mode="w:gz") as tar:
|
||||||
|
for path, file in files.items():
|
||||||
|
info = TarInfo(path)
|
||||||
|
info.size = len(file)
|
||||||
|
tar.addfile(info, BytesIO(file.encode()))
|
||||||
|
final_data = fh.getvalue()
|
||||||
|
return final_data
|
||||||
@@ -6,7 +6,7 @@ from djangoql.ast import Name
|
|||||||
from djangoql.exceptions import DjangoQLError
|
from djangoql.exceptions import DjangoQLError
|
||||||
from djangoql.queryset import apply_search
|
from djangoql.queryset import apply_search
|
||||||
from djangoql.schema import DjangoQLSchema
|
from djangoql.schema import DjangoQLSchema
|
||||||
from rest_framework.filters import BaseFilterBackend, SearchFilter
|
from rest_framework.filters import SearchFilter
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class BaseSchema(DjangoQLSchema):
|
|||||||
return super().resolve_name(name)
|
return super().resolve_name(name)
|
||||||
|
|
||||||
|
|
||||||
class QLSearch(BaseFilterBackend):
|
class QLSearch(SearchFilter):
|
||||||
"""rest_framework search filter which uses DjangoQL"""
|
"""rest_framework search filter which uses DjangoQL"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from authentik.stages.authenticator.models import Device
|
|||||||
|
|
||||||
|
|
||||||
class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||||
"""Setup Google Chrome Device-trust connection"""
|
"""Setup Google Chrome Device Trust connection"""
|
||||||
|
|
||||||
credentials = models.JSONField()
|
credentials = models.JSONField()
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
|
|||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||||
|
from authentik.stages.user_login.stage import PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE
|
||||||
|
|
||||||
# Header we get from chrome that initiates verified access
|
# Header we get from chrome that initiates verified access
|
||||||
HEADER_DEVICE_TRUST = "X-Device-Trust"
|
HEADER_DEVICE_TRUST = "X-Device-Trust"
|
||||||
@@ -27,6 +28,8 @@ HEADER_ACCESS_CHALLENGE_RESPONSE = "X-Verified-Access-Challenge-Response"
|
|||||||
# Header value for x-device-trust that initiates the flow
|
# Header value for x-device-trust that initiates the flow
|
||||||
DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess"
|
DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess"
|
||||||
|
|
||||||
|
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS = "endpoints"
|
||||||
|
|
||||||
|
|
||||||
@method_decorator(xframe_options_sameorigin, name="dispatch")
|
@method_decorator(xframe_options_sameorigin, name="dispatch")
|
||||||
class GoogleChromeDeviceTrustConnector(View):
|
class GoogleChromeDeviceTrustConnector(View):
|
||||||
@@ -81,7 +84,14 @@ class GoogleChromeDeviceTrustConnector(View):
|
|||||||
)
|
)
|
||||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint")
|
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint")
|
||||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault("endpoints", [])
|
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS]["endpoints"].append(response)
|
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS, []
|
||||||
|
)
|
||||||
|
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS][PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS].append(
|
||||||
|
response
|
||||||
|
)
|
||||||
|
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||||
|
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, True
|
||||||
|
)
|
||||||
request.session[SESSION_KEY_PLAN] = flow_plan
|
request.session[SESSION_KEY_PLAN] = flow_plan
|
||||||
return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html")
|
return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html")
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
"""API URLs"""
|
"""API URLs"""
|
||||||
|
|
||||||
from authentik.enterprise.api import LicenseViewSet
|
from django.urls import path
|
||||||
|
|
||||||
|
from authentik.enterprise.api import LicenseViewSet, SupportBundleView
|
||||||
|
|
||||||
api_urlpatterns = [
|
api_urlpatterns = [
|
||||||
("enterprise/license", LicenseViewSet),
|
("enterprise/license", LicenseViewSet),
|
||||||
|
path(
|
||||||
|
"enterprise/support_bundle/", SupportBundleView.as_view(), name="enterprise_support_bundle"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from authentik.core.models import Group, User
|
|||||||
from authentik.events.models import Event, EventAction, Notification
|
from authentik.events.models import Event, EventAction, Notification
|
||||||
from authentik.events.utils import model_to_dict
|
from authentik.events.utils import model_to_dict
|
||||||
from authentik.lib.sentry import should_ignore_exception
|
from authentik.lib.sentry import should_ignore_exception
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
from authentik.lib.utils.errors import exception_to_dict
|
||||||
from authentik.stages.authenticator_static.models import StaticToken
|
from authentik.stages.authenticator_static.models import StaticToken
|
||||||
|
|
||||||
IGNORED_MODELS = tuple(
|
IGNORED_MODELS = tuple(
|
||||||
@@ -170,14 +170,16 @@ class AuditMiddleware:
|
|||||||
thread = EventNewThread(
|
thread = EventNewThread(
|
||||||
EventAction.SUSPICIOUS_REQUEST,
|
EventAction.SUSPICIOUS_REQUEST,
|
||||||
request,
|
request,
|
||||||
message=exception_to_string(exception),
|
message=str(exception),
|
||||||
|
exception=exception_to_dict(exception),
|
||||||
)
|
)
|
||||||
thread.run()
|
thread.run()
|
||||||
elif not should_ignore_exception(exception):
|
elif not should_ignore_exception(exception):
|
||||||
thread = EventNewThread(
|
thread = EventNewThread(
|
||||||
EventAction.SYSTEM_EXCEPTION,
|
EventAction.SYSTEM_EXCEPTION,
|
||||||
request,
|
request,
|
||||||
message=exception_to_string(exception),
|
message=str(exception),
|
||||||
|
exception=exception_to_dict(exception),
|
||||||
)
|
)
|
||||||
thread.run()
|
thread.run()
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from authentik.events.utils import (
|
|||||||
)
|
)
|
||||||
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
|
from authentik.lib.utils.errors import exception_to_dict
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.policies.models import PolicyBindingModel
|
from authentik.policies.models import PolicyBindingModel
|
||||||
@@ -163,6 +164,12 @@ class Event(SerializerModel, ExpiringModel):
|
|||||||
event = Event(action=action, app=app, context=cleaned_kwargs)
|
event = Event(action=action, app=app, context=cleaned_kwargs)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
def with_exception(self, exc: Exception) -> "Event":
|
||||||
|
"""Add data from 'exc' to the event in a database-saveable format"""
|
||||||
|
self.context.setdefault("message", str(exc))
|
||||||
|
self.context["exception"] = exception_to_dict(exc)
|
||||||
|
return self
|
||||||
|
|
||||||
def set_user(self, user: User) -> "Event":
|
def set_user(self, user: User) -> "Event":
|
||||||
"""Set `.user` based on user, ensuring the correct attributes are copied.
|
"""Set `.user` based on user, ensuring the correct attributes are copied.
|
||||||
This should only be used when self.from_http is *not* used."""
|
This should only be used when self.from_http is *not* used."""
|
||||||
|
|||||||
@@ -127,8 +127,8 @@ class SystemTask(TenantTask):
|
|||||||
)
|
)
|
||||||
Event.new(
|
Event.new(
|
||||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||||
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
|
message=f"Task {self.__name__} encountered an error",
|
||||||
).save()
|
).with_exception(exc).save()
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
|||||||
policy_engine.mode = PolicyEngineMode.MODE_ANY
|
policy_engine.mode = PolicyEngineMode.MODE_ANY
|
||||||
policy_engine.empty_result = False
|
policy_engine.empty_result = False
|
||||||
policy_engine.use_cache = False
|
policy_engine.use_cache = False
|
||||||
|
policy_engine.request.obj = event
|
||||||
policy_engine.request.context["event"] = event
|
policy_engine.request.context["event"] = event
|
||||||
policy_engine.build()
|
policy_engine.build()
|
||||||
result = policy_engine.result
|
result = policy_engine.result
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ from authentik.flows.planner import (
|
|||||||
)
|
)
|
||||||
from authentik.flows.stage import AccessDeniedStage, StageView
|
from authentik.flows.stage import AccessDeniedStage, StageView
|
||||||
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
|
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.lib.utils.reflection import all_subclasses, class_to_path
|
from authentik.lib.utils.reflection import all_subclasses, class_to_path
|
||||||
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
|
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
@@ -239,8 +238,8 @@ class FlowExecutorView(APIView):
|
|||||||
capture_exception(exc)
|
capture_exception(exc)
|
||||||
Event.new(
|
Event.new(
|
||||||
action=EventAction.SYSTEM_EXCEPTION,
|
action=EventAction.SYSTEM_EXCEPTION,
|
||||||
message=exception_to_string(exc),
|
message="System exception during flow execution.",
|
||||||
).from_http(self.request)
|
).with_exception(exc).from_http(self.request)
|
||||||
challenge = FlowErrorChallenge(self.request, exc)
|
challenge = FlowErrorChallenge(self.request, exc)
|
||||||
challenge.is_valid(raise_exception=True)
|
challenge.is_valid(raise_exception=True)
|
||||||
return to_stage_response(self.request, HttpChallengeResponse(challenge))
|
return to_stage_response(self.request, HttpChallengeResponse(challenge))
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from authentik.events.models import Event, EventAction
|
|||||||
from authentik.lib.expression.exceptions import ControlFlowException
|
from authentik.lib.expression.exceptions import ControlFlowException
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
|
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
@@ -106,9 +105,9 @@ class BaseOutgoingSyncClient[
|
|||||||
# Value error can be raised when assigning invalid data to an attribute
|
# Value error can be raised when assigning invalid data to an attribute
|
||||||
Event.new(
|
Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
message="Failed to evaluate property-mapping",
|
||||||
mapping=exc.mapping,
|
mapping=exc.mapping,
|
||||||
).save()
|
).with_exception(exc).save()
|
||||||
raise StopSync(exc, obj, exc.mapping) from exc
|
raise StopSync(exc, obj, exc.mapping) from exc
|
||||||
if not raw_final_object:
|
if not raw_final_object:
|
||||||
raise StopSync(ValueError("No mappings configured"), obj)
|
raise StopSync(ValueError("No mappings configured"), obj)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from traceback import extract_tb
|
from traceback import extract_tb
|
||||||
|
|
||||||
|
from structlog.tracebacks import ExceptionDictTransformer
|
||||||
|
|
||||||
from authentik.lib.utils.reflection import class_to_path
|
from authentik.lib.utils.reflection import class_to_path
|
||||||
|
|
||||||
TRACEBACK_HEADER = "Traceback (most recent call last):"
|
TRACEBACK_HEADER = "Traceback (most recent call last):"
|
||||||
@@ -17,3 +19,8 @@ def exception_to_string(exc: Exception) -> str:
|
|||||||
f"{class_to_path(exc.__class__)}: {str(exc)}",
|
f"{class_to_path(exc.__class__)}: {str(exc)}",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def exception_to_dict(exc: Exception) -> dict:
|
||||||
|
"""Format exception as a dictionary"""
|
||||||
|
return ExceptionDictTransformer()((type(exc), exc, exc.__traceback__))
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from authentik.events.models import Event, EventAction
|
|||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.models import InheritanceForeignKey, SerializerModel
|
from authentik.lib.models import InheritanceForeignKey, SerializerModel
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.outposts.controllers.k8s.utils import get_namespace
|
from authentik.outposts.controllers.k8s.utils import get_namespace
|
||||||
|
|
||||||
OUR_VERSION = parse(__version__)
|
OUR_VERSION = parse(__version__)
|
||||||
@@ -326,9 +325,8 @@ class Outpost(SerializerModel, ManagedModel):
|
|||||||
"While setting the permissions for the service-account, a "
|
"While setting the permissions for the service-account, a "
|
||||||
"permission was not found: Check "
|
"permission was not found: Check "
|
||||||
"https://goauthentik.io/docs/troubleshooting/missing_permission"
|
"https://goauthentik.io/docs/troubleshooting/missing_permission"
|
||||||
)
|
),
|
||||||
+ exception_to_string(exc),
|
).with_exception(exc).set_user(user).save()
|
||||||
).set_user(user).save()
|
|
||||||
else:
|
else:
|
||||||
app_label, perm = model_or_perm.split(".")
|
app_label, perm = model_or_perm.split(".")
|
||||||
permission = Permission.objects.filter(
|
permission = Permission.objects.filter(
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""authentik policy engine"""
|
"""authentik policy engine"""
|
||||||
|
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterable
|
||||||
from multiprocessing import Pipe, current_process
|
from multiprocessing import Pipe, current_process
|
||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
from time import perf_counter
|
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
from django.db.models import Count, Q, QuerySet
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from sentry_sdk import start_span
|
from sentry_sdk import start_span
|
||||||
from sentry_sdk.tracing import Span
|
from sentry_sdk.tracing import Span
|
||||||
@@ -67,14 +67,11 @@ class PolicyEngine:
|
|||||||
self.__processes: list[PolicyProcessInfo] = []
|
self.__processes: list[PolicyProcessInfo] = []
|
||||||
self.use_cache = True
|
self.use_cache = True
|
||||||
self.__expected_result_count = 0
|
self.__expected_result_count = 0
|
||||||
|
self.__static_result: PolicyResult | None = None
|
||||||
|
|
||||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
|
||||||
"""Make sure all Policies are their respective classes"""
|
"""Make sure all Policies are their respective classes"""
|
||||||
return (
|
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
|
||||||
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
|
|
||||||
.order_by("order")
|
|
||||||
.iterator()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_policy_type(self, binding: PolicyBinding):
|
def _check_policy_type(self, binding: PolicyBinding):
|
||||||
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
|
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
|
||||||
@@ -84,10 +81,17 @@ class PolicyEngine:
|
|||||||
def _check_cache(self, binding: PolicyBinding):
|
def _check_cache(self, binding: PolicyBinding):
|
||||||
if not self.use_cache:
|
if not self.use_cache:
|
||||||
return False
|
return False
|
||||||
before = perf_counter()
|
# It's a bit silly to time this, but
|
||||||
|
with HIST_POLICIES_EXECUTION_TIME.labels(
|
||||||
|
binding_order=binding.order,
|
||||||
|
binding_target_type=binding.target_type,
|
||||||
|
binding_target_name=binding.target_name,
|
||||||
|
object_pk=str(self.request.obj.pk),
|
||||||
|
object_type=class_to_path(self.request.obj.__class__),
|
||||||
|
mode="cache_retrieve",
|
||||||
|
).time():
|
||||||
key = cache_key(binding, self.request)
|
key = cache_key(binding, self.request)
|
||||||
cached_policy = cache.get(key, None)
|
cached_policy = cache.get(key, None)
|
||||||
duration = max(perf_counter() - before, 0)
|
|
||||||
if not cached_policy:
|
if not cached_policy:
|
||||||
return False
|
return False
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
@@ -96,18 +100,47 @@ class PolicyEngine:
|
|||||||
cache_key=key,
|
cache_key=key,
|
||||||
request=self.request,
|
request=self.request,
|
||||||
)
|
)
|
||||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
|
||||||
binding_order=binding.order,
|
|
||||||
binding_target_type=binding.target_type,
|
|
||||||
binding_target_name=binding.target_name,
|
|
||||||
object_pk=str(self.request.obj.pk),
|
|
||||||
object_type=class_to_path(self.request.obj.__class__),
|
|
||||||
mode="cache_retrieve",
|
|
||||||
).observe(duration)
|
|
||||||
# It's a bit silly to time this, but
|
|
||||||
self.__cached_policies.append(cached_policy)
|
self.__cached_policies.append(cached_policy)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
||||||
|
"""Check static bindings if possible"""
|
||||||
|
aggrs = {
|
||||||
|
"total": Count(
|
||||||
|
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if self.request.user.pk:
|
||||||
|
all_groups = self.request.user.all_groups()
|
||||||
|
aggrs["passing"] = Count(
|
||||||
|
"pk",
|
||||||
|
filter=Q(
|
||||||
|
Q(
|
||||||
|
Q(user=self.request.user) | Q(group__in=all_groups),
|
||||||
|
negate=False,
|
||||||
|
)
|
||||||
|
| Q(
|
||||||
|
Q(~Q(user=self.request.user), user__isnull=False)
|
||||||
|
| Q(~Q(group__in=all_groups), group__isnull=False),
|
||||||
|
negate=True,
|
||||||
|
),
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
matched_bindings = bindings.aggregate(**aggrs)
|
||||||
|
passing = False
|
||||||
|
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
|
||||||
|
# If we didn't find any static bindings, do nothing
|
||||||
|
return
|
||||||
|
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||||
|
if matched_bindings.get("passing", 0) > 0:
|
||||||
|
# Any passing static binding -> passing
|
||||||
|
passing = True
|
||||||
|
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||||
|
# No matching static bindings but at least one is configured -> not passing
|
||||||
|
passing = False
|
||||||
|
self.__static_result = PolicyResult(passing)
|
||||||
|
|
||||||
def build(self) -> "PolicyEngine":
|
def build(self) -> "PolicyEngine":
|
||||||
"""Build wrapper which monitors performance"""
|
"""Build wrapper which monitors performance"""
|
||||||
with (
|
with (
|
||||||
@@ -123,7 +156,12 @@ class PolicyEngine:
|
|||||||
span: Span
|
span: Span
|
||||||
span.set_data("pbm", self.__pbm)
|
span.set_data("pbm", self.__pbm)
|
||||||
span.set_data("request", self.request)
|
span.set_data("request", self.request)
|
||||||
for binding in self.iterate_bindings():
|
bindings = self.bindings()
|
||||||
|
policy_bindings = bindings
|
||||||
|
if isinstance(bindings, QuerySet):
|
||||||
|
self.compute_static_bindings(bindings)
|
||||||
|
policy_bindings = [x for x in bindings if x.policy]
|
||||||
|
for binding in policy_bindings:
|
||||||
self.__expected_result_count += 1
|
self.__expected_result_count += 1
|
||||||
|
|
||||||
self._check_policy_type(binding)
|
self._check_policy_type(binding)
|
||||||
@@ -153,10 +191,13 @@ class PolicyEngine:
|
|||||||
@property
|
@property
|
||||||
def result(self) -> PolicyResult:
|
def result(self) -> PolicyResult:
|
||||||
"""Get policy-checking result"""
|
"""Get policy-checking result"""
|
||||||
|
self.__processes.sort(key=lambda x: x.binding.order)
|
||||||
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
|
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
|
||||||
all_results = list(process_results + self.__cached_policies)
|
all_results = list(process_results + self.__cached_policies)
|
||||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||||
raise AssertionError("Got less results than polices")
|
raise AssertionError("Got less results than polices")
|
||||||
|
if self.__static_result:
|
||||||
|
all_results.append(self.__static_result)
|
||||||
# No results, no policies attached -> passing
|
# No results, no policies attached -> passing
|
||||||
if len(all_results) == 0:
|
if len(all_results) == 0:
|
||||||
return PolicyResult(self.empty_result)
|
return PolicyResult(self.empty_result)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.events.models import Event
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
||||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||||
from authentik.policies.exceptions import PolicyException
|
from authentik.policies.exceptions import PolicyException
|
||||||
@@ -45,6 +46,10 @@ class PolicyEvaluator(BaseEvaluator):
|
|||||||
self.set_http_request(request.http_request)
|
self.set_http_request(request.http_request)
|
||||||
self._context["request"] = request
|
self._context["request"] = request
|
||||||
self._context["context"] = request.context
|
self._context["context"] = request.context
|
||||||
|
if request.obj and isinstance(request.obj, Event):
|
||||||
|
self._context["ak_client_ip"] = ip_address(
|
||||||
|
request.obj.client_ip or ClientIPMiddleware.default_ip
|
||||||
|
)
|
||||||
|
|
||||||
def set_http_request(self, request: HttpRequest):
|
def set_http_request(self, request: HttpRequest):
|
||||||
"""Update context based on http request"""
|
"""Update context based on http request"""
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
from authentik.lib.utils.errors import exception_to_dict
|
||||||
from authentik.lib.utils.reflection import class_to_path
|
from authentik.lib.utils.reflection import class_to_path
|
||||||
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
|
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
|
||||||
from authentik.policies.exceptions import PolicyException
|
from authentik.policies.exceptions import PolicyException
|
||||||
@@ -95,10 +95,13 @@ class PolicyProcess(PROCESS_CLASS):
|
|||||||
except PolicyException as exc:
|
except PolicyException as exc:
|
||||||
# Either use passed original exception or whatever we have
|
# Either use passed original exception or whatever we have
|
||||||
src_exc = exc.src_exc if exc.src_exc else exc
|
src_exc = exc.src_exc if exc.src_exc else exc
|
||||||
error_string = exception_to_string(src_exc)
|
|
||||||
# Create policy exception event, only when we're not debugging
|
# Create policy exception event, only when we're not debugging
|
||||||
if not self.request.debug:
|
if not self.request.debug:
|
||||||
self.create_event(EventAction.POLICY_EXCEPTION, message=error_string)
|
self.create_event(
|
||||||
|
EventAction.POLICY_EXCEPTION,
|
||||||
|
message="Policy failed to execute",
|
||||||
|
exception=exception_to_dict(src_exc),
|
||||||
|
)
|
||||||
LOGGER.debug("P_ENG(proc): error, using failure result", exc=src_exc)
|
LOGGER.debug("P_ENG(proc): error, using failure result", exc=src_exc)
|
||||||
policy_result = PolicyResult(self.binding.failure_result, str(src_exc))
|
policy_result = PolicyResult(self.binding.failure_result, str(src_exc))
|
||||||
policy_result.source_binding = self.binding
|
policy_result.source_binding = self.binding
|
||||||
@@ -143,5 +146,5 @@ class PolicyProcess(PROCESS_CLASS):
|
|||||||
try:
|
try:
|
||||||
self.connection.send(self.profiling_wrapper())
|
self.connection.send(self.profiling_wrapper())
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOGGER.warning("Policy failed to run", exc=exception_to_string(exc))
|
LOGGER.warning("Policy failed to run", exc=exc)
|
||||||
self.connection.send(PolicyResult(False, str(exc)))
|
self.connection.send(PolicyResult(False, str(exc)))
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
"""policy engine tests"""
|
"""policy engine tests"""
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
from django.db import connections
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.models import Group
|
||||||
|
from authentik.core.tests.utils import create_test_user
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.policies.dummy.models import DummyPolicy
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
@@ -19,7 +22,7 @@ class TestPolicyEngine(TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
clear_policy_cache()
|
clear_policy_cache()
|
||||||
self.user = create_test_admin_user()
|
self.user = create_test_user()
|
||||||
self.policy_false = DummyPolicy.objects.create(
|
self.policy_false = DummyPolicy.objects.create(
|
||||||
name=generate_id(), result=False, wait_min=0, wait_max=1
|
name=generate_id(), result=False, wait_min=0, wait_max=1
|
||||||
)
|
)
|
||||||
@@ -127,3 +130,58 @@ class TestPolicyEngine(TestCase):
|
|||||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||||
self.assertEqual(engine.build().passing, False)
|
self.assertEqual(engine.build().passing, False)
|
||||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||||
|
|
||||||
|
def test_engine_static_bindings(self):
|
||||||
|
"""Test static bindings"""
|
||||||
|
group_a = Group.objects.create(name=generate_id())
|
||||||
|
group_b = Group.objects.create(name=generate_id())
|
||||||
|
group_b.users.add(self.user)
|
||||||
|
user = create_test_user()
|
||||||
|
|
||||||
|
for case in [
|
||||||
|
{
|
||||||
|
"message": "Group, not member",
|
||||||
|
"binding_args": {"group": group_a},
|
||||||
|
"passing": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"message": "Group, member",
|
||||||
|
"binding_args": {"group": group_b},
|
||||||
|
"passing": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"message": "User, other",
|
||||||
|
"binding_args": {"user": user},
|
||||||
|
"passing": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"message": "User, same",
|
||||||
|
"binding_args": {"user": self.user},
|
||||||
|
"passing": True,
|
||||||
|
},
|
||||||
|
]:
|
||||||
|
with self.subTest():
|
||||||
|
pbm = PolicyBindingModel.objects.create()
|
||||||
|
for x in range(1000):
|
||||||
|
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
|
||||||
|
engine = PolicyEngine(pbm, self.user)
|
||||||
|
engine.use_cache = False
|
||||||
|
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||||
|
engine.build()
|
||||||
|
self.assertLess(ctx.final_queries, 1000)
|
||||||
|
self.assertEqual(engine.result.passing, case["passing"])
|
||||||
|
|
||||||
|
def test_engine_group_complex(self):
|
||||||
|
"""Test more complex group setups"""
|
||||||
|
group_a = Group.objects.create(name=generate_id())
|
||||||
|
group_b = Group.objects.create(name=generate_id(), parent=group_a)
|
||||||
|
user = create_test_user()
|
||||||
|
group_b.users.add(user)
|
||||||
|
pbm = PolicyBindingModel.objects.create()
|
||||||
|
PolicyBinding.objects.create(target=pbm, order=0, group=group_a)
|
||||||
|
engine = PolicyEngine(pbm, user)
|
||||||
|
engine.use_cache = False
|
||||||
|
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||||
|
engine.build()
|
||||||
|
self.assertLess(ctx.final_queries, 1000)
|
||||||
|
self.assertTrue(engine.result.passing)
|
||||||
|
|||||||
@@ -29,13 +29,12 @@ class TestPolicyProcess(TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
clear_policy_cache()
|
clear_policy_cache()
|
||||||
self.factory = RequestFactory()
|
self.factory = RequestFactory()
|
||||||
self.user = User.objects.create_user(username="policyuser")
|
self.user = User.objects.create_user(username=generate_id())
|
||||||
|
|
||||||
def test_group_passing(self):
|
def test_group_passing(self):
|
||||||
"""Test binding to group"""
|
"""Test binding to group"""
|
||||||
group = Group.objects.create(name="test-group")
|
group = Group.objects.create(name=generate_id())
|
||||||
group.users.add(self.user)
|
group.users.add(self.user)
|
||||||
group.save()
|
|
||||||
binding = PolicyBinding(group=group)
|
binding = PolicyBinding(group=group)
|
||||||
|
|
||||||
request = PolicyRequest(self.user)
|
request = PolicyRequest(self.user)
|
||||||
@@ -44,8 +43,7 @@ class TestPolicyProcess(TestCase):
|
|||||||
|
|
||||||
def test_group_negative(self):
|
def test_group_negative(self):
|
||||||
"""Test binding to group"""
|
"""Test binding to group"""
|
||||||
group = Group.objects.create(name="test-group")
|
group = Group.objects.create(name=generate_id())
|
||||||
group.save()
|
|
||||||
binding = PolicyBinding(group=group)
|
binding = PolicyBinding(group=group)
|
||||||
|
|
||||||
request = PolicyRequest(self.user)
|
request = PolicyRequest(self.user)
|
||||||
@@ -115,8 +113,10 @@ class TestPolicyProcess(TestCase):
|
|||||||
|
|
||||||
def test_exception(self):
|
def test_exception(self):
|
||||||
"""Test policy execution"""
|
"""Test policy execution"""
|
||||||
policy = Policy.objects.create(name="test-execution")
|
policy = Policy.objects.create(name=generate_id())
|
||||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
binding = PolicyBinding(
|
||||||
|
policy=policy, target=Application.objects.create(name=generate_id())
|
||||||
|
)
|
||||||
|
|
||||||
request = PolicyRequest(self.user)
|
request = PolicyRequest(self.user)
|
||||||
response = PolicyProcess(binding, request, None).execute()
|
response = PolicyProcess(binding, request, None).execute()
|
||||||
@@ -125,13 +125,15 @@ class TestPolicyProcess(TestCase):
|
|||||||
def test_execution_logging(self):
|
def test_execution_logging(self):
|
||||||
"""Test policy execution creates event"""
|
"""Test policy execution creates event"""
|
||||||
policy = DummyPolicy.objects.create(
|
policy = DummyPolicy.objects.create(
|
||||||
name="test-execution-logging",
|
name=generate_id(),
|
||||||
result=False,
|
result=False,
|
||||||
wait_min=0,
|
wait_min=0,
|
||||||
wait_max=1,
|
wait_max=1,
|
||||||
execution_logging=True,
|
execution_logging=True,
|
||||||
)
|
)
|
||||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
binding = PolicyBinding(
|
||||||
|
policy=policy, target=Application.objects.create(name=generate_id())
|
||||||
|
)
|
||||||
|
|
||||||
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
|
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
|
||||||
http_request.user = self.user
|
http_request.user = self.user
|
||||||
@@ -186,13 +188,15 @@ class TestPolicyProcess(TestCase):
|
|||||||
def test_execution_logging_anonymous(self):
|
def test_execution_logging_anonymous(self):
|
||||||
"""Test policy execution creates event with anonymous user"""
|
"""Test policy execution creates event with anonymous user"""
|
||||||
policy = DummyPolicy.objects.create(
|
policy = DummyPolicy.objects.create(
|
||||||
name="test-execution-logging-anon",
|
name=generate_id(),
|
||||||
result=False,
|
result=False,
|
||||||
wait_min=0,
|
wait_min=0,
|
||||||
wait_max=1,
|
wait_max=1,
|
||||||
execution_logging=True,
|
execution_logging=True,
|
||||||
)
|
)
|
||||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
binding = PolicyBinding(
|
||||||
|
policy=policy, target=Application.objects.create(name=generate_id())
|
||||||
|
)
|
||||||
|
|
||||||
user = AnonymousUser()
|
user = AnonymousUser()
|
||||||
|
|
||||||
@@ -219,9 +223,9 @@ class TestPolicyProcess(TestCase):
|
|||||||
|
|
||||||
def test_raises(self):
|
def test_raises(self):
|
||||||
"""Test policy that raises error"""
|
"""Test policy that raises error"""
|
||||||
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
|
policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
|
||||||
binding = PolicyBinding(
|
binding = PolicyBinding(
|
||||||
policy=policy_raises, target=Application.objects.create(name="test")
|
policy=policy_raises, target=Application.objects.create(name=generate_id())
|
||||||
)
|
)
|
||||||
|
|
||||||
request = PolicyRequest(self.user)
|
request = PolicyRequest(self.user)
|
||||||
@@ -237,4 +241,4 @@ class TestPolicyProcess(TestCase):
|
|||||||
self.assertEqual(len(events), 1)
|
self.assertEqual(len(events), 1)
|
||||||
event = events.first()
|
event = events.first()
|
||||||
self.assertEqual(event.user["username"], self.user.username)
|
self.assertEqual(event.user["username"], self.user.username)
|
||||||
self.assertIn("division by zero", event.context["message"])
|
self.assertIn("Policy failed to execute", event.context["message"])
|
||||||
|
|||||||
@@ -15,12 +15,14 @@ class OAuth2Error(SentryIgnoredException):
|
|||||||
|
|
||||||
error: str
|
error: str
|
||||||
description: str
|
description: str
|
||||||
|
cause: str | None = None
|
||||||
|
|
||||||
def create_dict(self):
|
def create_dict(self, request: HttpRequest):
|
||||||
"""Return error as dict for JSON Rendering"""
|
"""Return error as dict for JSON Rendering"""
|
||||||
return {
|
return {
|
||||||
"error": self.error,
|
"error": self.error,
|
||||||
"error_description": self.description,
|
"error_description": self.description,
|
||||||
|
"request_id": request.request_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -31,9 +33,15 @@ class OAuth2Error(SentryIgnoredException):
|
|||||||
return Event.new(
|
return Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message=message or self.description,
|
message=message or self.description,
|
||||||
|
cause=self.cause,
|
||||||
|
error=self.error,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def with_cause(self, cause: str):
|
||||||
|
self.cause = cause
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RedirectUriError(OAuth2Error):
|
class RedirectUriError(OAuth2Error):
|
||||||
"""The request fails due to a missing, invalid, or mismatching
|
"""The request fails due to a missing, invalid, or mismatching
|
||||||
@@ -243,13 +251,14 @@ class TokenRevocationError(OAuth2Error):
|
|||||||
self.description = self.errors[error]
|
self.description = self.errors[error]
|
||||||
|
|
||||||
|
|
||||||
class DeviceCodeError(OAuth2Error):
|
class DeviceCodeError(TokenError):
|
||||||
"""
|
"""
|
||||||
Device-code flow errors
|
Device-code flow errors
|
||||||
See https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
See https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||||
|
Can also use codes form TokenError
|
||||||
"""
|
"""
|
||||||
|
|
||||||
errors = {
|
errors = TokenError.errors | {
|
||||||
"authorization_pending": (
|
"authorization_pending": (
|
||||||
"The authorization request is still pending as the end user hasn't "
|
"The authorization request is still pending as the end user hasn't "
|
||||||
"yet completed the user-interaction steps"
|
"yet completed the user-interaction steps"
|
||||||
@@ -261,10 +270,15 @@ class DeviceCodeError(OAuth2Error):
|
|||||||
"authorization request but SHOULD wait for user interaction before "
|
"authorization request but SHOULD wait for user interaction before "
|
||||||
"restarting to avoid unnecessary polling."
|
"restarting to avoid unnecessary polling."
|
||||||
),
|
),
|
||||||
|
"slow_down": (
|
||||||
|
'A variant of "authorization_pending", the authorization request is'
|
||||||
|
"still pending and polling should continue, but the interval MUST"
|
||||||
|
"be increased by 5 seconds for this and all subsequent requests."
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, error: str):
|
def __init__(self, error: str):
|
||||||
super().__init__()
|
super().__init__(error)
|
||||||
self.error = error
|
self.error = error
|
||||||
self.description = self.errors[error]
|
self.description = self.errors[error]
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
|||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.providers.oauth2.constants import TOKEN_TYPE
|
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
|
||||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||||
from authentik.providers.oauth2.models import (
|
from authentik.providers.oauth2.models import (
|
||||||
AccessToken,
|
AccessToken,
|
||||||
@@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||||
|
|
||||||
def test_invalid_client_id(self):
|
def test_invalid_client_id(self):
|
||||||
"""Test invalid client ID"""
|
"""Test invalid client ID"""
|
||||||
@@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.error, "request_not_supported")
|
||||||
|
|
||||||
def test_invalid_redirect_uri(self):
|
def test_invalid_redirect_uri_missing(self):
|
||||||
"""test missing/invalid redirect URI"""
|
"""test missing redirect URI"""
|
||||||
OAuth2Provider.objects.create(
|
OAuth2Provider.objects.create(
|
||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
with self.assertRaises(RedirectUriError):
|
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
|
||||||
|
|
||||||
|
def test_invalid_redirect_uri(self):
|
||||||
|
"""test invalid redirect URI"""
|
||||||
|
OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
client_id="test",
|
||||||
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||||
|
)
|
||||||
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||||
|
|
||||||
def test_blocked_redirect_uri(self):
|
def test_blocked_redirect_uri(self):
|
||||||
"""test missing/invalid redirect URI"""
|
"""test missing/invalid redirect URI"""
|
||||||
@@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||||
|
|
||||||
def test_invalid_redirect_uri_empty(self):
|
def test_invalid_redirect_uri_empty(self):
|
||||||
"""test missing/invalid redirect URI"""
|
"""test missing/invalid redirect URI"""
|
||||||
@@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[],
|
redirect_uris=[],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||||
|
|
||||||
def test_redirect_uri_invalid_regex(self):
|
def test_redirect_uri_invalid_regex(self):
|
||||||
"""test missing/invalid redirect URI (invalid regex)"""
|
"""test missing/invalid redirect URI (invalid regex)"""
|
||||||
@@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
with self.assertRaises(RedirectUriError) as cm:
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||||
|
|
||||||
def test_empty_redirect_uri(self):
|
def test_redirect_uri_regex(self):
|
||||||
"""test empty redirect URI (configure in provider)"""
|
"""test valid redirect URI (regex)"""
|
||||||
OAuth2Provider.objects.create(
|
OAuth2Provider.objects.create(
|
||||||
name=generate_id(),
|
name=generate_id(),
|
||||||
client_id="test",
|
client_id="test",
|
||||||
authorization_flow=create_test_flow(),
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
|
||||||
)
|
)
|
||||||
with self.assertRaises(RedirectUriError):
|
|
||||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
|
||||||
OAuthAuthorizationParams.from_request(request)
|
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"client_id": "test",
|
"client_id": "test",
|
||||||
"redirect_uri": "http://localhost",
|
"redirect_uri": "http://foo.bar.baz",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
@@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
GrantTypes.IMPLICIT,
|
GrantTypes.IMPLICIT,
|
||||||
)
|
)
|
||||||
# Implicit without openid scope
|
# Implicit without openid scope
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthorizeError):
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
OAuthAuthorizationParams.from_request(request)
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||||
|
|
||||||
def test_full_code(self):
|
def test_full_code(self):
|
||||||
"""Test full authorization"""
|
"""Test full authorization"""
|
||||||
@@ -613,3 +619,54 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_openid_missing_invalid(self):
|
||||||
|
"""test request requiring an OpenID scope to be set"""
|
||||||
|
OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
client_id="test",
|
||||||
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||||
|
)
|
||||||
|
request = self.factory.get(
|
||||||
|
"/",
|
||||||
|
data={
|
||||||
|
"response_type": "id_token",
|
||||||
|
"client_id": "test",
|
||||||
|
"redirect_uri": "http://localhost",
|
||||||
|
"scope": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with self.assertRaises(AuthorizeError) as cm:
|
||||||
|
OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertEqual(cm.exception.cause, "scope_openid_missing")
|
||||||
|
|
||||||
|
@apply_blueprint("system/providers-oauth2.yaml")
|
||||||
|
def test_offline_access_invalid(self):
|
||||||
|
"""test request for offline_access with invalid response type"""
|
||||||
|
provider = OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
client_id="test",
|
||||||
|
authorization_flow=create_test_flow(),
|
||||||
|
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||||
|
)
|
||||||
|
provider.property_mappings.set(
|
||||||
|
ScopeMapping.objects.filter(
|
||||||
|
managed__in=[
|
||||||
|
"goauthentik.io/providers/oauth2/scope-openid",
|
||||||
|
"goauthentik.io/providers/oauth2/scope-offline_access",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
request = self.factory.get(
|
||||||
|
"/",
|
||||||
|
data={
|
||||||
|
"response_type": "id_token",
|
||||||
|
"client_id": "test",
|
||||||
|
"redirect_uri": "http://localhost",
|
||||||
|
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
|
||||||
|
"nonce": generate_id(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
parsed = OAuthAuthorizationParams.from_request(request)
|
||||||
|
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
|
||||||
|
|||||||
@@ -68,7 +68,11 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_provider(self):
|
def test_no_provider(self):
|
||||||
@@ -87,7 +91,11 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_permission_denied(self):
|
def test_permission_denied(self):
|
||||||
@@ -110,7 +118,11 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_incorrect_scopes(self):
|
def test_incorrect_scopes(self):
|
||||||
|
|||||||
@@ -68,7 +68,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_wrong_token(self):
|
def test_wrong_token(self):
|
||||||
@@ -85,7 +89,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_provider(self):
|
def test_no_provider(self):
|
||||||
@@ -104,7 +112,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_permission_denied(self):
|
def test_permission_denied(self):
|
||||||
@@ -127,7 +139,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_successful(self):
|
def test_successful(self):
|
||||||
|
|||||||
@@ -68,7 +68,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_wrong_token(self):
|
def test_wrong_token(self):
|
||||||
@@ -86,7 +90,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_provider(self):
|
def test_no_provider(self):
|
||||||
@@ -106,7 +114,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_permission_denied(self):
|
def test_permission_denied(self):
|
||||||
@@ -130,7 +142,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
|||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": TokenError.errors["invalid_grant"],
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_successful(self):
|
def test_successful(self):
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ class TestTokenPKCE(OAuthTestCase):
|
|||||||
"revoked, does not match the redirection URI used in the authorization "
|
"revoked, does not match the redirection URI used in the authorization "
|
||||||
"request, or was issued to another client"
|
"request, or was issued to another client"
|
||||||
),
|
),
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
@@ -136,6 +137,7 @@ class TestTokenPKCE(OAuthTestCase):
|
|||||||
"revoked, does not match the redirection URI used in the authorization "
|
"revoked, does not match the redirection URI used in the authorization "
|
||||||
"request, or was issued to another client"
|
"request, or was issued to another client"
|
||||||
),
|
),
|
||||||
|
"request_id": response.headers["X-authentik-id"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 400)
|
self.assertEqual(response.status_code, 400)
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
|
|||||||
allowed_redirect_urls = self.provider.redirect_uris
|
allowed_redirect_urls = self.provider.redirect_uris
|
||||||
if not self.redirect_uri:
|
if not self.redirect_uri:
|
||||||
LOGGER.warning("Missing redirect uri.")
|
LOGGER.warning("Missing redirect uri.")
|
||||||
raise RedirectUriError("", allowed_redirect_urls)
|
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||||
|
|
||||||
if len(allowed_redirect_urls) < 1:
|
if len(allowed_redirect_urls) < 1:
|
||||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||||
@@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
)
|
)
|
||||||
if not match_found:
|
if not match_found:
|
||||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||||
|
"redirect_uri_no_match"
|
||||||
|
)
|
||||||
# Check against forbidden schemes
|
# Check against forbidden schemes
|
||||||
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
||||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||||
|
"redirect_uri_forbidden_scheme"
|
||||||
|
)
|
||||||
|
|
||||||
def check_scope(self, github_compat=False):
|
def check_scope(self, github_compat=False):
|
||||||
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
||||||
@@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
|
|||||||
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||||
):
|
):
|
||||||
LOGGER.warning("Missing 'openid' scope.")
|
LOGGER.warning("Missing 'openid' scope.")
|
||||||
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
|
raise AuthorizeError(
|
||||||
|
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
||||||
|
).with_cause("scope_openid_missing")
|
||||||
if SCOPE_OFFLINE_ACCESS in self.scope:
|
if SCOPE_OFFLINE_ACCESS in self.scope:
|
||||||
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
||||||
# Don't explicitly request consent with offline_access, as the spec allows for
|
# Don't explicitly request consent with offline_access, as the spec allows for
|
||||||
@@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
|
|||||||
return
|
return
|
||||||
if not self.nonce:
|
if not self.nonce:
|
||||||
LOGGER.warning("Missing nonce for OpenID Request")
|
LOGGER.warning("Missing nonce for OpenID Request")
|
||||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
raise AuthorizeError(
|
||||||
|
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||||
|
).with_cause("nonce_missing")
|
||||||
|
|
||||||
def check_code_challenge(self):
|
def check_code_challenge(self):
|
||||||
"""PKCE validation of the transformation method."""
|
"""PKCE validation of the transformation method."""
|
||||||
@@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
|||||||
self.request, github_compat=self.github_compat
|
self.request, github_compat=self.github_compat
|
||||||
)
|
)
|
||||||
except AuthorizeError as error:
|
except AuthorizeError as error:
|
||||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
|
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
|
||||||
raise RequestValidationError(error.get_response(self.request)) from None
|
raise RequestValidationError(error.get_response(self.request)) from None
|
||||||
except OAuth2Error as error:
|
except OAuth2Error as error:
|
||||||
LOGGER.warning(error.description)
|
LOGGER.warning(error.description, cause=error.cause)
|
||||||
raise RequestValidationError(
|
raise RequestValidationError(
|
||||||
bad_request_message(self.request, error.description, title=error.error)
|
bad_request_message(self.request, error.description, title=error.error)
|
||||||
) from None
|
) from None
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, JsonResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
@@ -14,7 +14,9 @@ from structlog.stdlib import get_logger
|
|||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
|
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||||
|
from authentik.providers.oauth2.utils import TokenResponse
|
||||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
@@ -28,38 +30,36 @@ class DeviceView(View):
|
|||||||
provider: OAuth2Provider
|
provider: OAuth2Provider
|
||||||
scopes: list[str] = []
|
scopes: list[str] = []
|
||||||
|
|
||||||
def parse_request(self) -> HttpResponse | None:
|
def parse_request(self):
|
||||||
"""Parse incoming request"""
|
"""Parse incoming request"""
|
||||||
client_id = self.request.POST.get("client_id", None)
|
client_id = self.request.POST.get("client_id", None)
|
||||||
if not client_id:
|
if not client_id:
|
||||||
return HttpResponseBadRequest()
|
raise DeviceCodeError("invalid_client")
|
||||||
provider = OAuth2Provider.objects.filter(
|
provider = OAuth2Provider.objects.filter(client_id=client_id).first()
|
||||||
client_id=client_id,
|
|
||||||
).first()
|
|
||||||
if not provider:
|
if not provider:
|
||||||
return HttpResponseBadRequest()
|
raise DeviceCodeError("invalid_client")
|
||||||
try:
|
try:
|
||||||
_ = provider.application
|
_ = provider.application
|
||||||
except Application.DoesNotExist:
|
except Application.DoesNotExist:
|
||||||
return HttpResponseBadRequest()
|
raise DeviceCodeError("invalid_client") from None
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.scopes = self.request.POST.get("scope", "").split(" ")
|
self.scopes = self.request.POST.get("scope", "").split(" ")
|
||||||
return None
|
|
||||||
|
|
||||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
throttle = AnonRateThrottle()
|
throttle = AnonRateThrottle()
|
||||||
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
|
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
|
||||||
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
||||||
if not throttle.allow_request(request, self):
|
if not throttle.allow_request(request, self):
|
||||||
return HttpResponse(status=429)
|
return TokenResponse(DeviceCodeError("slow_down").create_dict(request), status=429)
|
||||||
return super().dispatch(request, *args, **kwargs)
|
return super().dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
def post(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Generate device token"""
|
"""Generate device token"""
|
||||||
resp = self.parse_request()
|
try:
|
||||||
if resp:
|
self.parse_request()
|
||||||
return resp
|
except DeviceCodeError as exc:
|
||||||
|
return TokenResponse(exc.create_dict(request), status=400)
|
||||||
until = timedelta_from_string(self.provider.access_code_validity)
|
until = timedelta_from_string(self.provider.access_code_validity)
|
||||||
token: DeviceToken = DeviceToken.objects.create(
|
token: DeviceToken = DeviceToken.objects.create(
|
||||||
expires=now() + until, provider=self.provider, _scope=" ".join(self.scopes)
|
expires=now() + until, provider=self.provider, _scope=" ".join(self.scopes)
|
||||||
@@ -67,7 +67,7 @@ class DeviceView(View):
|
|||||||
device_url = self.request.build_absolute_uri(
|
device_url = self.request.build_absolute_uri(
|
||||||
reverse("authentik_providers_oauth2_root:device-login")
|
reverse("authentik_providers_oauth2_root:device-login")
|
||||||
)
|
)
|
||||||
return JsonResponse(
|
return TokenResponse(
|
||||||
{
|
{
|
||||||
"device_code": token.device_code,
|
"device_code": token.device_code,
|
||||||
"verification_uri": device_url,
|
"verification_uri": device_url,
|
||||||
|
|||||||
@@ -598,9 +598,9 @@ class TokenView(View):
|
|||||||
return TokenResponse(self.create_device_code_response())
|
return TokenResponse(self.create_device_code_response())
|
||||||
raise TokenError("unsupported_grant_type")
|
raise TokenError("unsupported_grant_type")
|
||||||
except (TokenError, DeviceCodeError) as error:
|
except (TokenError, DeviceCodeError) as error:
|
||||||
return TokenResponse(error.create_dict(), status=400)
|
return TokenResponse(error.create_dict(request), status=400)
|
||||||
except UserAuthError as error:
|
except UserAuthError as error:
|
||||||
return TokenResponse(error.create_dict(), status=403)
|
return TokenResponse(error.create_dict(request), status=403)
|
||||||
|
|
||||||
def create_code_response(self) -> dict[str, Any]:
|
def create_code_response(self) -> dict[str, Any]:
|
||||||
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1"""
|
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1"""
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class TokenRevokeView(View):
|
|||||||
|
|
||||||
return TokenResponse(data={}, status=200)
|
return TokenResponse(data={}, status=200)
|
||||||
except TokenRevocationError as exc:
|
except TokenRevocationError as exc:
|
||||||
return TokenResponse(exc.create_dict(), status=401)
|
return TokenResponse(exc.create_dict(request), status=401)
|
||||||
except Http404:
|
except Http404:
|
||||||
# Token not found should return a HTTP 200
|
# Token not found should return a HTTP 200
|
||||||
# https://datatracker.ietf.org/doc/html/rfc7009#section-2.2
|
# https://datatracker.ietf.org/doc/html/rfc7009#section-2.2
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
|
|||||||
# Buffer sizes for large headers with JWTs
|
# Buffer sizes for large headers with JWTs
|
||||||
"nginx.ingress.kubernetes.io/proxy-buffers-number": "4",
|
"nginx.ingress.kubernetes.io/proxy-buffers-number": "4",
|
||||||
"nginx.ingress.kubernetes.io/proxy-buffer-size": "16k",
|
"nginx.ingress.kubernetes.io/proxy-buffer-size": "16k",
|
||||||
|
"nginx.ingress.kubernetes.io/proxy-busy-buffers-size": "32k",
|
||||||
# Enable TLS in traefik
|
# Enable TLS in traefik
|
||||||
"traefik.ingress.kubernetes.io/router.tls": "true",
|
"traefik.ingress.kubernetes.io/router.tls": "true",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from authentik.core.models import Application
|
|||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.expression.exceptions import ControlFlowException
|
from authentik.lib.expression.exceptions import ControlFlowException
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.policies.types import PolicyResult
|
from authentik.policies.types import PolicyResult
|
||||||
@@ -142,9 +141,9 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
|||||||
# Value error can be raised when assigning invalid data to an attribute
|
# Value error can be raised when assigning invalid data to an attribute
|
||||||
Event.new(
|
Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
message="Failed to evaluate property-mapping",
|
||||||
mapping=exc.mapping,
|
mapping=exc.mapping,
|
||||||
).save()
|
).with_exception(exc).save()
|
||||||
return None
|
return None
|
||||||
return b64encode(packet.RequestPacket()).decode()
|
return b64encode(packet.RequestPacket()).decode()
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
||||||
from pydanticscim.group import Group as BaseGroup
|
from pydanticscim.group import Group as BaseGroup
|
||||||
from pydanticscim.responses import PatchOperation as BasePatchOperation
|
from pydanticscim.responses import PatchOperation as BasePatchOperation
|
||||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||||
@@ -12,19 +12,95 @@ from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
|||||||
from pydanticscim.service_provider import (
|
from pydanticscim.service_provider import (
|
||||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||||
)
|
)
|
||||||
|
from pydanticscim.user import AddressKind
|
||||||
from pydanticscim.user import User as BaseUser
|
from pydanticscim.user import User as BaseUser
|
||||||
|
|
||||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||||
|
|
||||||
|
|
||||||
|
class Address(BaseModel):
|
||||||
|
formatted: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="The full mailing address, formatted for display "
|
||||||
|
"or use with a mailing label. This attribute MAY contain newlines.",
|
||||||
|
)
|
||||||
|
streetAddress: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="The full street address component, which may "
|
||||||
|
"include house number, street name, P.O. box, and multi-line "
|
||||||
|
"extended street address information. This attribute MAY contain newlines.",
|
||||||
|
)
|
||||||
|
locality: str | None = Field(None, description="The city or locality component.")
|
||||||
|
region: str | None = Field(None, description="The state or region component.")
|
||||||
|
postalCode: str | None = Field(None, description="The zip code or postal code component.")
|
||||||
|
country: str | None = Field(None, description="The country name component.")
|
||||||
|
type: AddressKind | None = Field(
|
||||||
|
None,
|
||||||
|
description="A label indicating the attribute's function, e.g., 'work' or 'home'.",
|
||||||
|
)
|
||||||
|
primary: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Manager(BaseModel):
|
||||||
|
value: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="The id of the SCIM resource representingthe User's manager. REQUIRED.",
|
||||||
|
)
|
||||||
|
ref: AnyUrl | None = Field(
|
||||||
|
None,
|
||||||
|
alias="$ref",
|
||||||
|
description="The URI of the SCIM resource representing the User's manager. REQUIRED.",
|
||||||
|
)
|
||||||
|
displayName: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="The displayName of the User's manager. OPTIONAL and READ-ONLY.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseUser(BaseModel):
|
||||||
|
employeeNumber: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Numeric or alphanumeric identifier assigned to a person, "
|
||||||
|
"typically based on order of hire or association with anorganization.",
|
||||||
|
)
|
||||||
|
costCenter: str | None = Field(None, description="Identifies the name of a cost center.")
|
||||||
|
organization: str | None = Field(None, description="Identifies the name of an organization.")
|
||||||
|
division: str | None = Field(None, description="Identifies the name of a division.")
|
||||||
|
department: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Numeric or alphanumeric identifier assigned to a person,"
|
||||||
|
" typically based on order of hire or association with anorganization.",
|
||||||
|
)
|
||||||
|
manager: Manager | None = Field(
|
||||||
|
None,
|
||||||
|
description="The User's manager. A complex type that optionally allows "
|
||||||
|
"service providers to represent organizational hierarchy by referencing"
|
||||||
|
" the 'id' attribute of another User.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class User(BaseUser):
|
class User(BaseUser):
|
||||||
"""Modified User schema with added externalId field"""
|
"""Modified User schema with added externalId field"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(serialize_by_alias=True)
|
||||||
|
|
||||||
id: str | int | None = None
|
id: str | int | None = None
|
||||||
schemas: list[str] = [SCIM_USER_SCHEMA]
|
schemas: list[str] = [SCIM_USER_SCHEMA]
|
||||||
externalId: str | None = None
|
externalId: str | None = None
|
||||||
meta: dict | None = None
|
meta: dict | None = None
|
||||||
|
addresses: list[Address] | None = Field(
|
||||||
|
None,
|
||||||
|
description=(
|
||||||
|
"A physical mailing address for this User. Canonical type "
|
||||||
|
"values of 'work', 'home', and 'other'."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
enterprise_user: EnterpriseUser | None = Field(
|
||||||
|
default=None,
|
||||||
|
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||||
|
serialization_alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Group(BaseGroup):
|
class Group(BaseGroup):
|
||||||
@@ -92,7 +168,7 @@ class PatchOperation(BasePatchOperation):
|
|||||||
"""PatchOperation with optional path"""
|
"""PatchOperation with optional path"""
|
||||||
|
|
||||||
op: PatchOp
|
op: PatchOp
|
||||||
path: str | None
|
path: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class SCIMError(BaseSCIMError):
|
class SCIMError(BaseSCIMError):
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
|
|||||||
|
|
||||||
from authentik import get_full_version
|
from authentik import get_full_version
|
||||||
from authentik.lib.sentry import should_ignore_exception
|
from authentik.lib.sentry import should_ignore_exception
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
|
|
||||||
# set the default Django settings module for the 'celery' program.
|
# set the default Django settings module for the 'celery' program.
|
||||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
|
||||||
@@ -83,8 +82,8 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
|
|||||||
CTX_TASK_ID.set(...)
|
CTX_TASK_ID.set(...)
|
||||||
if not should_ignore_exception(exception):
|
if not should_ignore_exception(exception):
|
||||||
Event.new(
|
Event.new(
|
||||||
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
|
EventAction.SYSTEM_EXCEPTION, message="Failed to execute task", task_id=task_id
|
||||||
).save()
|
).with_exception(exception).save()
|
||||||
|
|
||||||
|
|
||||||
def _get_startup_tasks_default_tenant() -> list[Callable]:
|
def _get_startup_tasks_default_tenant() -> list[Callable]:
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ class ReadyView(View):
|
|||||||
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
try:
|
try:
|
||||||
for db_conn in connections.all():
|
for db_conn in connections.all():
|
||||||
|
# Force connection reload
|
||||||
|
db_conn.connect()
|
||||||
_ = db_conn.cursor()
|
_ = db_conn.cursor()
|
||||||
except OperationalError: # pragma: no cover
|
except OperationalError: # pragma: no cover
|
||||||
return HttpResponse(status=503)
|
return HttpResponse(status=503)
|
||||||
|
|||||||
@@ -156,16 +156,17 @@ SPECTACULAR_SETTINGS = {
|
|||||||
},
|
},
|
||||||
"ENUM_NAME_OVERRIDES": {
|
"ENUM_NAME_OVERRIDES": {
|
||||||
"CountryCodeEnum": "django_countries.countries",
|
"CountryCodeEnum": "django_countries.countries",
|
||||||
|
"DeviceClassesEnum": "authentik.stages.authenticator_validate.models.DeviceClasses",
|
||||||
"EventActions": "authentik.events.models.EventAction",
|
"EventActions": "authentik.events.models.EventAction",
|
||||||
"FlowDesignationEnum": "authentik.flows.models.FlowDesignation",
|
"FlowDesignationEnum": "authentik.flows.models.FlowDesignation",
|
||||||
"FlowLayoutEnum": "authentik.flows.models.FlowLayout",
|
"FlowLayoutEnum": "authentik.flows.models.FlowLayout",
|
||||||
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
|
||||||
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
|
||||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
|
||||||
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
|
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
|
||||||
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
|
|
||||||
"UserTypeEnum": "authentik.core.models.UserTypes",
|
|
||||||
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
|
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
|
||||||
|
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
||||||
|
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||||
|
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
||||||
|
"UserTypeEnum": "authentik.core.models.UserTypes",
|
||||||
|
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
|
||||||
},
|
},
|
||||||
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
|
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
|
||||||
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
|
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from pathlib import Path
|
|||||||
from secrets import token_urlsafe
|
from secrets import token_urlsafe
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TransactionTestCase
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
|
|
||||||
class TestRoot(TestCase):
|
class TestRoot(TransactionTestCase):
|
||||||
"""Test root application"""
|
"""Test root application"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from authentik.events.models import TaskStatus
|
|||||||
from authentik.events.system_tasks import SystemTask
|
from authentik.events.system_tasks import SystemTask
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
from authentik.sources.kerberos.models import KerberosSource
|
from authentik.sources.kerberos.models import KerberosSource
|
||||||
from authentik.sources.kerberos.sync import KerberosSync
|
from authentik.sources.kerberos.sync import KerberosSync
|
||||||
@@ -64,5 +63,5 @@ def kerberos_sync_single(self, source_pk: str):
|
|||||||
syncer.sync()
|
syncer.sync()
|
||||||
self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages)
|
self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages)
|
||||||
except StopSync as exc:
|
except StopSync as exc:
|
||||||
LOGGER.warning(exception_to_string(exc))
|
LOGGER.warning("Error syncing kerberos", exc=exc, source=source)
|
||||||
self.set_error(exc)
|
self.set_error(exc)
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from authentik.events.models import TaskStatus
|
|||||||
from authentik.events.system_tasks import SystemTask
|
from authentik.events.system_tasks import SystemTask
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
from authentik.sources.ldap.models import LDAPSource
|
from authentik.sources.ldap.models import LDAPSource
|
||||||
@@ -149,5 +148,5 @@ def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key:
|
|||||||
cache.delete(page_cache_key)
|
cache.delete(page_cache_key)
|
||||||
except (LDAPException, StopSync) as exc:
|
except (LDAPException, StopSync) as exc:
|
||||||
# No explicit event is created here as .set_status with an error will do that
|
# No explicit event is created here as .set_status with an error will do that
|
||||||
LOGGER.warning(exception_to_string(exc))
|
LOGGER.warning("Failed to sync LDAP", exc=exc, source=source)
|
||||||
self.set_error(exc)
|
self.set_error(exc)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ AUTHENTIK_SOURCES_OAUTH_TYPES = [
|
|||||||
"authentik.sources.oauth.types.apple",
|
"authentik.sources.oauth.types.apple",
|
||||||
"authentik.sources.oauth.types.azure_ad",
|
"authentik.sources.oauth.types.azure_ad",
|
||||||
"authentik.sources.oauth.types.discord",
|
"authentik.sources.oauth.types.discord",
|
||||||
|
"authentik.sources.oauth.types.entra_id",
|
||||||
"authentik.sources.oauth.types.facebook",
|
"authentik.sources.oauth.types.facebook",
|
||||||
"authentik.sources.oauth.types.github",
|
"authentik.sources.oauth.types.github",
|
||||||
"authentik.sources.oauth.types.gitlab",
|
"authentik.sources.oauth.types.gitlab",
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ class GoogleOAuthSource(CreatableType, OAuthSource):
|
|||||||
|
|
||||||
|
|
||||||
class AzureADOAuthSource(CreatableType, OAuthSource):
|
class AzureADOAuthSource(CreatableType, OAuthSource):
|
||||||
"""Social Login using Azure AD."""
|
"""(Deprecated) Social Login using Azure AD."""
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
abstract = True
|
abstract = True
|
||||||
@@ -240,6 +240,17 @@ class AzureADOAuthSource(CreatableType, OAuthSource):
|
|||||||
verbose_name_plural = _("Azure AD OAuth Sources")
|
verbose_name_plural = _("Azure AD OAuth Sources")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: When removing this, add a migration for OAuthSource that sets
|
||||||
|
# provider_type to `entraid` if it is currently `azuread`
|
||||||
|
class EntraIDOAuthSource(CreatableType, OAuthSource):
|
||||||
|
"""Social Login using Entra ID."""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
verbose_name = _("Entra ID OAuth Source")
|
||||||
|
verbose_name_plural = _("Entra ID OAuth Sources")
|
||||||
|
|
||||||
|
|
||||||
class OpenIDConnectOAuthSource(CreatableType, OAuthSource):
|
class OpenIDConnectOAuthSource(CreatableType, OAuthSource):
|
||||||
"""Login using a Generic OpenID-Connect compliant provider."""
|
"""Login using a Generic OpenID-Connect compliant provider."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
"""azure ad Type tests"""
|
"""Entra ID Type tests"""
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback, AzureADType
|
from authentik.sources.oauth.types.entra_id import EntraIDOAuthCallback, EntraIDType
|
||||||
|
|
||||||
# https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2
|
# https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2
|
||||||
AAD_USER = {
|
EID_USER = {
|
||||||
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users/$entity",
|
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users/$entity",
|
||||||
"@odata.id": (
|
"@odata.id": (
|
||||||
"https://graph.microsoft.com/v2/7ce9b89e-646a-41d2-9fa6-8371c6a8423d/"
|
"https://graph.microsoft.com/v2/7ce9b89e-646a-41d2-9fa6-8371c6a8423d/"
|
||||||
@@ -41,11 +41,11 @@ class TestTypeAzureAD(TestCase):
|
|||||||
|
|
||||||
def test_enroll_context(self):
|
def test_enroll_context(self):
|
||||||
"""Test azure_ad Enrollment context"""
|
"""Test azure_ad Enrollment context"""
|
||||||
ak_context = AzureADType().get_base_user_properties(source=self.source, info=AAD_USER)
|
ak_context = EntraIDType().get_base_user_properties(source=self.source, info=EID_USER)
|
||||||
self.assertEqual(ak_context["username"], AAD_USER["userPrincipalName"])
|
self.assertEqual(ak_context["username"], EID_USER["userPrincipalName"])
|
||||||
self.assertEqual(ak_context["email"], AAD_USER["mail"])
|
self.assertEqual(ak_context["email"], EID_USER["mail"])
|
||||||
self.assertEqual(ak_context["name"], AAD_USER["displayName"])
|
self.assertEqual(ak_context["name"], EID_USER["displayName"])
|
||||||
|
|
||||||
def test_user_id(self):
|
def test_user_id(self):
|
||||||
"""Test azure AD user ID"""
|
"""Test Entra ID user ID"""
|
||||||
self.assertEqual(AzureADOAuthCallback().get_user_id(AAD_USER), AAD_USER["id"])
|
self.assertEqual(EntraIDOAuthCallback().get_user_id(EID_USER), EID_USER["id"])
|
||||||
@@ -1,105 +1,17 @@
|
|||||||
"""AzureAD OAuth2 Views"""
|
"""AzureAD OAuth2 Views"""
|
||||||
|
|
||||||
from typing import Any
|
from authentik.sources.oauth.types.entra_id import EntraIDType
|
||||||
|
from authentik.sources.oauth.types.registry import registry
|
||||||
|
|
||||||
from requests import RequestException
|
# TODO: When removing this, add a migration for OAuthSource that sets
|
||||||
from structlog.stdlib import get_logger
|
# provider_type to `entraid` if it is currently `azuread`
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
|
||||||
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
|
||||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
|
||||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
|
||||||
|
|
||||||
LOGGER = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class AzureADOAuthRedirect(OAuthRedirect):
|
|
||||||
"""Azure AD OAuth2 Redirect"""
|
|
||||||
|
|
||||||
def get_additional_parameters(self, source): # pragma: no cover
|
|
||||||
return {
|
|
||||||
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AzureADClient(UserprofileHeaderAuthClient):
|
|
||||||
"""Fetch AzureAD group information"""
|
|
||||||
|
|
||||||
def get_profile_info(self, token):
|
|
||||||
profile_data = super().get_profile_info(token)
|
|
||||||
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
|
|
||||||
return profile_data
|
|
||||||
group_response = self.session.request(
|
|
||||||
"get",
|
|
||||||
"https://graph.microsoft.com/v1.0/me/memberOf",
|
|
||||||
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
group_response.raise_for_status()
|
|
||||||
except RequestException as exc:
|
|
||||||
LOGGER.warning(
|
|
||||||
"Unable to fetch user profile",
|
|
||||||
exc=exc,
|
|
||||||
response=exc.response.text if exc.response else str(exc),
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
profile_data["raw_groups"] = group_response.json()
|
|
||||||
return profile_data
|
|
||||||
|
|
||||||
|
|
||||||
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
|
|
||||||
"""AzureAD OAuth2 Callback"""
|
|
||||||
|
|
||||||
client_class = AzureADClient
|
|
||||||
|
|
||||||
def get_user_id(self, info: dict[str, str]) -> str:
|
|
||||||
# Default try to get `id` for the Graph API endpoint
|
|
||||||
# fallback to OpenID logic in case the profile URL was changed
|
|
||||||
return info.get("id", super().get_user_id(info))
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register()
|
||||||
class AzureADType(SourceType):
|
class AzureADType(EntraIDType):
|
||||||
"""Azure AD Type definition"""
|
"""Azure AD Type definition"""
|
||||||
|
|
||||||
callback_view = AzureADOAuthCallback
|
|
||||||
redirect_view = AzureADOAuthRedirect
|
|
||||||
verbose_name = "Azure AD"
|
verbose_name = "Azure AD"
|
||||||
name = "azuread"
|
name = "azuread"
|
||||||
|
|
||||||
urls_customizable = True
|
urls_customizable = True
|
||||||
|
|
||||||
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
|
||||||
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
|
|
||||||
profile_url = "https://graph.microsoft.com/v1.0/me"
|
|
||||||
oidc_well_known_url = (
|
|
||||||
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
|
|
||||||
)
|
|
||||||
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
|
||||||
|
|
||||||
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
|
||||||
|
|
||||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
|
||||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
|
||||||
# Format group info
|
|
||||||
groups = []
|
|
||||||
group_id_dict = {}
|
|
||||||
for group in info.get("raw_groups", {}).get("value", []):
|
|
||||||
if group["@odata.type"] != "#microsoft.graph.group":
|
|
||||||
continue
|
|
||||||
groups.append(group["id"])
|
|
||||||
group_id_dict[group["id"]] = group
|
|
||||||
info["raw_groups"] = group_id_dict
|
|
||||||
return {
|
|
||||||
"username": info.get("userPrincipalName"),
|
|
||||||
"email": mail,
|
|
||||||
"name": info.get("displayName"),
|
|
||||||
"groups": groups,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_base_group_properties(self, source, group_id, **kwargs):
|
|
||||||
raw_group = kwargs["info"]["raw_groups"][group_id]
|
|
||||||
return {
|
|
||||||
"name": raw_group["displayName"],
|
|
||||||
}
|
|
||||||
|
|||||||
102
authentik/sources/oauth/types/entra_id.py
Normal file
102
authentik/sources/oauth/types/entra_id.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""EntraID OAuth2 Views"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from requests import RequestException
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||||
|
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||||
|
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||||
|
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||||
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class EntraIDOAuthRedirect(OAuthRedirect):
|
||||||
|
"""Entra ID OAuth2 Redirect"""
|
||||||
|
|
||||||
|
def get_additional_parameters(self, source): # pragma: no cover
|
||||||
|
return {
|
||||||
|
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EntraIDClient(UserprofileHeaderAuthClient):
|
||||||
|
"""Fetch EntraID group information"""
|
||||||
|
|
||||||
|
def get_profile_info(self, token):
|
||||||
|
profile_data = super().get_profile_info(token)
|
||||||
|
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
|
||||||
|
return profile_data
|
||||||
|
group_response = self.session.request(
|
||||||
|
"get",
|
||||||
|
"https://graph.microsoft.com/v1.0/me/memberOf",
|
||||||
|
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
group_response.raise_for_status()
|
||||||
|
except RequestException as exc:
|
||||||
|
LOGGER.warning(
|
||||||
|
"Unable to fetch user profile",
|
||||||
|
exc=exc,
|
||||||
|
response=exc.response.text if exc.response else str(exc),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
profile_data["raw_groups"] = group_response.json()
|
||||||
|
return profile_data
|
||||||
|
|
||||||
|
|
||||||
|
class EntraIDOAuthCallback(OpenIDConnectOAuth2Callback):
    """EntraID OAuth2 Callback"""

    client_class = EntraIDClient

    def get_user_id(self, info: dict[str, str]) -> str:
        """Return the Graph API `id` field when present, else the OpenID user ID.

        The original `info.get("id", super().get_user_id(info))` evaluated the
        fallback eagerly on every call (dict.get defaults are not lazy), running
        the OpenID lookup even when `id` was present. Check the key explicitly
        so the fallback only runs when actually needed.
        """
        # Default try to get `id` for the Graph API endpoint
        if "id" in info:
            return info["id"]
        # fallback to OpenID logic in case the profile URL was changed
        return super().get_user_id(info)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register()
class EntraIDType(SourceType):
    """Entra ID Type definition"""

    callback_view = EntraIDOAuthCallback
    redirect_view = EntraIDOAuthRedirect
    verbose_name = "Entra ID"
    name = "entraid"

    urls_customizable = True

    authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token"  # nosec
    profile_url = "https://graph.microsoft.com/v1.0/me"
    oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"

    authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY

    def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
        """Derive user properties from a Graph profile payload.

        Also rewrites info["raw_groups"] from the Graph list response into a
        dict keyed by group ID, for later lookup in get_base_group_properties.
        """
        mail = info.get("mail", None) or info.get("otherMails", [None])[0]
        # Keep only actual group objects; memberOf may return other directory types
        graph_groups = [
            entry
            for entry in info.get("raw_groups", {}).get("value", [])
            if entry["@odata.type"] == "#microsoft.graph.group"
        ]
        info["raw_groups"] = {entry["id"]: entry for entry in graph_groups}
        return {
            "username": info.get("userPrincipalName"),
            "email": mail,
            "name": info.get("displayName"),
            "groups": [entry["id"] for entry in graph_groups],
        }

    def get_base_group_properties(self, source, group_id, **kwargs):
        """Resolve a group's display name from the raw groups stashed on info."""
        raw_group = kwargs["info"]["raw_groups"][group_id]
        return {"name": raw_group["displayName"]}
|
||||||
@@ -18,6 +18,7 @@ class SCIMSourceGroupSerializer(SourceSerializer):
|
|||||||
model = SCIMSourceGroup
|
model = SCIMSourceGroup
|
||||||
fields = [
|
fields = [
|
||||||
"id",
|
"id",
|
||||||
|
"external_id",
|
||||||
"group",
|
"group",
|
||||||
"group_obj",
|
"group_obj",
|
||||||
"source",
|
"source",
|
||||||
@@ -31,5 +32,5 @@ class SCIMSourceGroupViewSet(UsedByMixin, ModelViewSet):
|
|||||||
queryset = SCIMSourceGroup.objects.all().select_related("group")
|
queryset = SCIMSourceGroup.objects.all().select_related("group")
|
||||||
serializer_class = SCIMSourceGroupSerializer
|
serializer_class = SCIMSourceGroupSerializer
|
||||||
filterset_fields = ["source__slug", "group__name", "group__group_uuid"]
|
filterset_fields = ["source__slug", "group__name", "group__group_uuid"]
|
||||||
search_fields = ["source__slug", "group__name", "attributes"]
|
search_fields = ["source__slug", "group__name", "attributes", "external_id"]
|
||||||
ordering = ["group__name"]
|
ordering = ["group__name"]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class SCIMSourceUserSerializer(SourceSerializer):
|
|||||||
model = SCIMSourceUser
|
model = SCIMSourceUser
|
||||||
fields = [
|
fields = [
|
||||||
"id",
|
"id",
|
||||||
|
"external_id",
|
||||||
"user",
|
"user",
|
||||||
"user_obj",
|
"user_obj",
|
||||||
"source",
|
"source",
|
||||||
@@ -31,5 +32,5 @@ class SCIMSourceUserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
queryset = SCIMSourceUser.objects.all().select_related("user")
|
queryset = SCIMSourceUser.objects.all().select_related("user")
|
||||||
serializer_class = SCIMSourceUserSerializer
|
serializer_class = SCIMSourceUserSerializer
|
||||||
filterset_fields = ["source__slug", "user__username", "user__id"]
|
filterset_fields = ["source__slug", "user__username", "user__id"]
|
||||||
search_fields = ["source__slug", "user__username", "attributes"]
|
search_fields = ["source__slug", "user__username", "attributes", "user__uuid", "external_id"]
|
||||||
ordering = ["user__username"]
|
ordering = ["user__username"]
|
||||||
|
|||||||
4
authentik/sources/scim/constants.py
Normal file
4
authentik/sources/scim/constants.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# Well-known SCIM 2.0 schema URNs (RFC 7643); used by the path lexer to keep
# URN prefixes intact when tokenizing attribute paths
SCIM_URN_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_URN_GROUP = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_URN_USER = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_URN_USER_ENTERPRISE = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
"""SCIM Errors"""
|
|
||||||
|
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
|
||||||
|
|
||||||
|
|
||||||
class PatchError(SentryIgnoredException):
|
|
||||||
"""Error raised within an atomic block when an error happened
|
|
||||||
so nothing is saved"""
|
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# Generated by Django 5.1.11 on 2025-07-13 01:07

import uuid
from django.db import migrations, models
from django.apps.registry import Apps

from django.db.backends.base.schema import BaseDatabaseSchemaEditor


def migrate_ext_id(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
    """Backfill external_id from the legacy primary key.

    Previously the SCIM-provided identifier was stored directly in `id`;
    copy it into the new `external_id` field so the new
    ("external_id", "source") unique constraint holds for existing rows.
    """
    SCIMSourceUser = apps.get_model("authentik_sources_scim", "SCIMSourceUser")
    SCIMSourceGroup = apps.get_model("authentik_sources_scim", "SCIMSourceGroup")
    db_alias = schema_editor.connection.alias
    for user in SCIMSourceUser.objects.using(db_alias).all():
        user.external_id = user.id
        user.save(update_fields=["external_id"])
    for group in SCIMSourceGroup.objects.using(db_alias).all():
        group.external_id = group.id
        group.save(update_fields=["external_id"])


class Migration(migrations.Migration):

    dependencies = [
        ("authentik_sources_scim", "0002_scimsourcepropertymapping"),
    ]

    operations = [
        # Drop the old unique constraints before restructuring the ID fields
        migrations.AlterUniqueTogether(
            name="scimsourcegroup",
            unique_together=set(),
        ),
        migrations.AlterUniqueTogether(
            name="scimsourceuser",
            unique_together=set(),
        ),
        # Add external_id as nullable first so existing rows remain valid
        migrations.AddField(
            model_name="scimsourcegroup",
            name="external_id",
            field=models.TextField(default=None, null=True),
            preserve_default=False,
        ),
        migrations.AddField(
            model_name="scimsourceuser",
            name="external_id",
            field=models.TextField(default=None, null=True),
            preserve_default=False,
        ),
        migrations.AlterUniqueTogether(
            name="scimsourcegroup",
            unique_together={("external_id", "source")},
        ),
        migrations.AlterUniqueTogether(
            name="scimsourceuser",
            unique_together={("external_id", "source")},
        ),
        # Data migration: copy legacy IDs; reverse is a no-op
        migrations.RunPython(migrate_ext_id, migrations.RunPython.noop),
        # Every row now has a value, so make external_id non-nullable
        migrations.AlterField(
            model_name="scimsourcegroup",
            name="external_id",
            field=models.TextField(),
            preserve_default=False,
        ),
        migrations.AlterField(
            model_name="scimsourceuser",
            name="external_id",
            field=models.TextField(),
            preserve_default=False,
        ),
        migrations.AddIndex(
            model_name="scimsourcegroup",
            index=models.Index(fields=["external_id"], name="authentik_s_externa_05e346_idx"),
        ),
        migrations.AddIndex(
            model_name="scimsourceuser",
            index=models.Index(fields=["external_id"], name="authentik_s_externa_4bd760_idx"),
        ),
        # The primary key becomes an internally generated UUID string
        migrations.AlterField(
            model_name="scimsourcegroup",
            name="id",
            field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
        ),
        migrations.AlterField(
            model_name="scimsourceuser",
            name="id",
            field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
        ),
        migrations.AddField(
            model_name="scimsourcegroup",
            name="last_update",
            field=models.DateTimeField(auto_now=True),
        ),
        migrations.AddField(
            model_name="scimsourceuser",
            name="last_update",
            field=models.DateTimeField(auto_now=True),
        ),
    ]
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""SCIM Source"""
|
"""SCIM Source"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
@@ -103,10 +104,12 @@ class SCIMSourcePropertyMapping(PropertyMapping):
|
|||||||
class SCIMSourceUser(SerializerModel):
|
class SCIMSourceUser(SerializerModel):
|
||||||
"""Mapping of a user and source to a SCIM user ID"""
|
"""Mapping of a user and source to a SCIM user ID"""
|
||||||
|
|
||||||
id = models.TextField(primary_key=True)
|
id = models.TextField(primary_key=True, default=uuid4)
|
||||||
|
external_id = models.TextField()
|
||||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||||
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
||||||
attributes = models.JSONField(default=dict)
|
attributes = models.JSONField(default=dict)
|
||||||
|
last_update = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> BaseSerializer:
|
def serializer(self) -> BaseSerializer:
|
||||||
@@ -115,7 +118,10 @@ class SCIMSourceUser(SerializerModel):
|
|||||||
return SCIMSourceUserSerializer
|
return SCIMSourceUserSerializer
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
unique_together = (("id", "user", "source"),)
|
unique_together = (("external_id", "source"),)
|
||||||
|
indexes = [
|
||||||
|
models.Index(fields=["external_id"]),
|
||||||
|
]
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"SCIM User {self.user_id} to {self.source_id}"
|
return f"SCIM User {self.user_id} to {self.source_id}"
|
||||||
@@ -124,10 +130,12 @@ class SCIMSourceUser(SerializerModel):
|
|||||||
class SCIMSourceGroup(SerializerModel):
|
class SCIMSourceGroup(SerializerModel):
|
||||||
"""Mapping of a group and source to a SCIM user ID"""
|
"""Mapping of a group and source to a SCIM user ID"""
|
||||||
|
|
||||||
id = models.TextField(primary_key=True)
|
id = models.TextField(primary_key=True, default=uuid4)
|
||||||
|
external_id = models.TextField()
|
||||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||||
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
||||||
attributes = models.JSONField(default=dict)
|
attributes = models.JSONField(default=dict)
|
||||||
|
last_update = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> BaseSerializer:
|
def serializer(self) -> BaseSerializer:
|
||||||
@@ -136,7 +144,10 @@ class SCIMSourceGroup(SerializerModel):
|
|||||||
return SCIMSourceGroupSerializer
|
return SCIMSourceGroupSerializer
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
unique_together = (("id", "group", "source"),)
|
unique_together = (("external_id", "source"),)
|
||||||
|
indexes = [
|
||||||
|
models.Index(fields=["external_id"]),
|
||||||
|
]
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"SCIM Group {self.group_id} to {self.source_id}"
|
return f"SCIM Group {self.group_id} to {self.source_id}"
|
||||||
|
|||||||
0
authentik/sources/scim/patch/__init__.py
Normal file
0
authentik/sources/scim/patch/__init__.py
Normal file
180
authentik/sources/scim/patch/lexer.py
Normal file
180
authentik/sources/scim/patch/lexer.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from authentik.sources.scim.constants import (
|
||||||
|
SCIM_URN_GROUP,
|
||||||
|
SCIM_URN_SCHEMA,
|
||||||
|
SCIM_URN_USER,
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Token types for SCIM path parsing
class TokenType(Enum):
    """Token categories emitted by SCIMPathLexer for SCIM paths/filters."""

    ATTRIBUTE = "ATTRIBUTE"
    DOT = "DOT"
    LBRACKET = "LBRACKET"
    RBRACKET = "RBRACKET"
    LPAREN = "LPAREN"
    RPAREN = "RPAREN"
    STRING = "STRING"
    NUMBER = "NUMBER"
    BOOLEAN = "BOOLEAN"
    NULL = "NULL"
    OPERATOR = "OPERATOR"
    AND = "AND"
    OR = "OR"
    NOT = "NOT"
    EOF = "EOF"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Token:
    """A single lexical token: its category, raw value, and source position."""

    # NOTE: value is usually str, but the lexer emits bool for BOOLEAN tokens
    # and None for NULL tokens
    type: TokenType
    value: str
    position: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMPathLexer:
    """Lexer for SCIM paths and filter expressions"""

    # Comparison operators of the SCIM filter grammar (RFC 7644)
    OPERATORS = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]

    def __init__(self, text: str):
        # Known schema URNs, matched as whole identifiers so the colons and
        # "2.0" dots inside them are not treated as path separators
        self.schema_urns = [
            SCIM_URN_SCHEMA,
            SCIM_URN_GROUP,
            SCIM_URN_USER,
            SCIM_URN_USER_ENTERPRISE,
        ]
        self.text = text
        self.pos = 0
        self.current_char = self.text[self.pos] if self.pos < len(self.text) else None

    def advance(self):
        """Move to next character"""
        self.pos += 1
        self.current_char = self.text[self.pos] if self.pos < len(self.text) else None

    def skip_whitespace(self):
        """Skip whitespace characters"""
        while self.current_char and self.current_char.isspace():
            self.advance()

    def read_string(self, quote_char):
        """Read a quoted string"""
        value = ""
        self.advance()  # Skip opening quote

        while self.current_char and self.current_char != quote_char:
            if self.current_char == "\\":
                # Backslash escapes the next character verbatim
                self.advance()
                if self.current_char:
                    value += self.current_char
                    self.advance()
            else:
                value += self.current_char
                self.advance()

        if self.current_char == quote_char:
            self.advance()  # Skip closing quote

        return value

    def read_number(self):
        """Read a number (integer or float)"""
        value = ""
        while self.current_char and (self.current_char.isdigit() or self.current_char == "."):
            value += self.current_char
            self.advance()
        return value

    def read_identifier(self):
        """Read an identifier (attribute name or operator) - supports URN format"""
        value = ""
        while self.current_char and (self.current_char.isalnum() or self.current_char in "_-:"):
            value += self.current_char
            self.advance()
            # If the identifier value so far is a schema URN, take that as the identifier and
            # treat the next part as a sub_attribute
            if value in self.schema_urns:
                # NOTE(review): current_char is overwritten without moving pos,
                # so the separator after the URN is replaced by a synthetic DOT
                self.current_char = "."
                return value

            # Handle dots within URN identifiers (like "2.0")
            # A dot is part of the identifier if it's followed by a digit
            if (
                self.current_char == "."
                and self.pos + 1 < len(self.text)
                and self.text[self.pos + 1].isdigit()
            ):
                value += self.current_char
                self.advance()
                # Continue reading digits after the dot
                while self.current_char and self.current_char.isdigit():
                    value += self.current_char
                    self.advance()

        return value

    def get_next_token(self) -> Token:  # noqa PLR0911
        """Get the next token from the input"""
        while self.current_char:
            if self.current_char.isspace():
                self.skip_whitespace()
                continue

            if self.current_char == ".":
                self.advance()
                return Token(TokenType.DOT, ".")

            if self.current_char == "[":
                self.advance()
                return Token(TokenType.LBRACKET, "[")

            if self.current_char == "]":
                self.advance()
                return Token(TokenType.RBRACKET, "]")

            if self.current_char == "(":
                self.advance()
                return Token(TokenType.LPAREN, "(")

            if self.current_char == ")":
                self.advance()
                return Token(TokenType.RPAREN, ")")

            if self.current_char in "\"'":
                quote_char = self.current_char
                value = self.read_string(quote_char)
                return Token(TokenType.STRING, value)

            if self.current_char.isdigit():
                value = self.read_number()
                return Token(TokenType.NUMBER, value)

            if self.current_char.isalpha() or self.current_char == "_":
                value = self.read_identifier()

                # Check for special keywords
                if value.lower() == "true":
                    return Token(TokenType.BOOLEAN, True)
                elif value.lower() == "false":
                    return Token(TokenType.BOOLEAN, False)
                elif value.lower() == "null":
                    return Token(TokenType.NULL, None)
                elif value.lower() == "and":
                    return Token(TokenType.AND, "and")
                elif value.lower() == "or":
                    return Token(TokenType.OR, "or")
                elif value.lower() == "not":
                    return Token(TokenType.NOT, "not")
                elif value.lower() in self.OPERATORS:
                    return Token(TokenType.OPERATOR, value.lower())
                else:
                    return Token(TokenType.ATTRIBUTE, value)

            # Skip unknown characters
            self.advance()

        return Token(TokenType.EOF, "")
|
||||||
131
authentik/sources/scim/patch/parser.py
Normal file
131
authentik/sources/scim/patch/parser.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from authentik.sources.scim.patch.lexer import SCIMPathLexer, TokenType
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMPathParser:
    """Parser for SCIM paths including filter expressions"""

    def __init__(self):
        # Both are (re)initialised by parse_path for every parsed path
        self.lexer = None
        self.current_token = None

    def parse_path(self, path: str | None) -> list[dict[str, Any]]:
        """Parse a SCIM path into components"""
        # Each component is a dict:
        # {"attribute": str, "filter": dict | None, "sub_attribute": str | None}
        self.lexer = SCIMPathLexer(path)
        self.current_token = self.lexer.get_next_token()

        components = []

        # NOTE(review): _parse_path_component returns None without consuming
        # input when the current token is not an attribute; a malformed path
        # could loop here indefinitely - confirm inputs are pre-validated
        while self.current_token.type != TokenType.EOF:
            component = self._parse_path_component()
            if component:
                components.append(component)

        return components

    def _parse_path_component(self) -> dict[str, Any] | None:
        """Parse a single path component"""
        if self.current_token.type != TokenType.ATTRIBUTE:
            return None

        attribute = self.current_token.value
        self._consume(TokenType.ATTRIBUTE)

        filter_expr = None
        sub_attribute = None

        # Check for filter expression
        if self.current_token.type == TokenType.LBRACKET:
            self._consume(TokenType.LBRACKET)
            filter_expr = self._parse_filter_expression()
            self._consume(TokenType.RBRACKET)

        # Check for sub-attribute
        if self.current_token.type == TokenType.DOT:
            self._consume(TokenType.DOT)
            if self.current_token.type == TokenType.ATTRIBUTE:
                sub_attribute = self.current_token.value
                self._consume(TokenType.ATTRIBUTE)

        return {"attribute": attribute, "filter": filter_expr, "sub_attribute": sub_attribute}

    def _parse_filter_expression(self) -> dict[str, Any] | None:
        """Parse a filter expression like 'primary eq true' or
        'type eq "work" and primary eq true'"""
        # OR has the lowest precedence, so it is the entry point
        return self._parse_or_expression()

    def _parse_or_expression(self) -> dict[str, Any] | None:
        """Parse OR expressions"""
        left = self._parse_and_expression()

        # Left-associative: a or b or c -> ((a or b) or c)
        while self.current_token.type == TokenType.OR:
            self._consume(TokenType.OR)
            right = self._parse_and_expression()
            left = {"type": "logical", "operator": "or", "left": left, "right": right}

        return left

    def _parse_and_expression(self) -> dict[str, Any] | None:
        """Parse AND expressions"""
        left = self._parse_primary_expression()

        while self.current_token.type == TokenType.AND:
            self._consume(TokenType.AND)
            right = self._parse_primary_expression()
            left = {"type": "logical", "operator": "and", "left": left, "right": right}

        return left

    def _parse_primary_expression(self) -> dict[str, Any] | None:
        """Parse primary expressions (attribute operator value)"""
        if self.current_token.type == TokenType.LPAREN:
            # Parenthesised sub-expression
            self._consume(TokenType.LPAREN)
            expr = self._parse_or_expression()
            self._consume(TokenType.RPAREN)
            return expr

        if self.current_token.type == TokenType.NOT:
            self._consume(TokenType.NOT)
            expr = self._parse_primary_expression()
            return {"type": "logical", "operator": "not", "operand": expr}

        if self.current_token.type != TokenType.ATTRIBUTE:
            return None

        attribute = self.current_token.value
        self._consume(TokenType.ATTRIBUTE)

        if self.current_token.type != TokenType.OPERATOR:
            return None

        operator = self.current_token.value
        self._consume(TokenType.OPERATOR)

        # Parse value
        # NOTE(review): the 'pr' (present) operator takes no value, in which
        # case value stays None
        value = None
        if self.current_token.type == TokenType.STRING:
            value = self.current_token.value
            self._consume(TokenType.STRING)
        elif self.current_token.type == TokenType.NUMBER:
            value = (
                float(self.current_token.value)
                if "." in self.current_token.value
                else int(self.current_token.value)
            )
            self._consume(TokenType.NUMBER)
        elif self.current_token.type == TokenType.BOOLEAN:
            value = self.current_token.value
            self._consume(TokenType.BOOLEAN)
        elif self.current_token.type == TokenType.NULL:
            value = None
            self._consume(TokenType.NULL)

        return {"type": "comparison", "attribute": attribute, "operator": operator, "value": value}

    def _consume(self, expected_type: TokenType):
        """Consume a token of the expected type"""
        if self.current_token.type == expected_type:
            self.current_token = self.lexer.get_next_token()
        else:
            raise ValueError(f"Expected {expected_type}, got {self.current_token.type}")
||||||
246
authentik/sources/scim/patch/processor.py
Normal file
246
authentik/sources/scim/patch/processor.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from authentik.providers.scim.clients.schema import PatchOp, PatchOperation
|
||||||
|
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||||
|
from authentik.sources.scim.patch.parser import SCIMPathParser
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMPatchProcessor:
|
||||||
|
"""Processes SCIM patch operations on Python dictionaries"""
|
||||||
|
|
||||||
|
    def __init__(self):
        # Single parser instance reused for every path in a patch request
        self.parser = SCIMPathParser()
|
||||||
|
|
||||||
|
    def apply_patches(self, data: dict[str, Any], patches: list[PatchOperation]) -> dict[str, Any]:
        """Apply a list of patch operations to the data"""
        # Shallow copy only - NOTE(review): nested dicts/lists are shared with
        # the input, so patch application mutates them in place; confirm
        # callers do not rely on `data` staying untouched
        result = data.copy()

        for _patch in patches:
            patch = PatchOperation.model_validate(_patch)
            if patch.path is None:
                # Handle operations with no path - value contains attribute paths as keys
                self._apply_bulk_operation(result, patch.op, patch.value)
            elif patch.op == PatchOp.add:
                self._apply_add(result, patch.path, patch.value)
            elif patch.op == PatchOp.remove:
                self._apply_remove(result, patch.path)
            elif patch.op == PatchOp.replace:
                self._apply_replace(result, patch.path, patch.value)

        return result
|
||||||
|
|
||||||
|
    def _apply_bulk_operation(
        self, data: dict[str, Any], operation: PatchOp, value: dict[str, Any]
    ):
        """Apply bulk operations when path is None"""
        # A non-dict value cannot carry per-attribute paths; silently ignored
        if not isinstance(value, dict):
            return
        # Each key of `value` is itself a SCIM path, applied individually
        for path, val in value.items():
            if operation == PatchOp.add:
                self._apply_add(data, path, val)
            elif operation == PatchOp.remove:
                self._apply_remove(data, path)
            elif operation == PatchOp.replace:
                self._apply_replace(data, path, val)
|
||||||
|
|
||||||
|
    def _apply_add(self, data: dict[str, Any], path: str, value: Any):
        """Apply ADD operation"""
        components = self.parser.parse_path(path)

        if len(components) == 1 and not components[0]["filter"]:
            # Simple path
            attr = components[0]["attribute"]
            if components[0]["sub_attribute"]:
                if attr not in data:
                    data[attr] = {}
                # Somewhat hacky workaround for the manager attribute of the enterprise schema
                # ideally we'd do this based on the schema
                if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
                    data[attr][components[0]["sub_attribute"]] = {"value": value}
                else:
                    data[attr][components[0]["sub_attribute"]] = value
            elif attr in data:
                # NOTE(review): assumes an existing attribute is list-like; an
                # ADD on an existing scalar would raise AttributeError - confirm
                data[attr].append(value)
            else:
                data[attr] = value
        else:
            # Complex path with filters
            self._navigate_and_modify(data, components, value, "add")
|
||||||
|
|
||||||
|
    def _apply_remove(self, data: dict[str, Any], path: str):
        """Apply REMOVE operation"""
        components = self.parser.parse_path(path)

        if len(components) == 1 and not components[0]["filter"]:
            # Simple path
            attr = components[0]["attribute"]
            if components[0]["sub_attribute"]:
                # Only remove the sub-attribute when the parent exists and is a dict
                if attr in data and isinstance(data[attr], dict):
                    data[attr].pop(components[0]["sub_attribute"], None)
            else:
                # Missing attributes are ignored (pop with default)
                data.pop(attr, None)
        else:
            # Complex path with filters
            self._navigate_and_modify(data, components, None, "remove")
|
||||||
|
|
||||||
|
    def _apply_replace(self, data: dict[str, Any], path: str, value: Any):
        """Apply REPLACE operation"""
        components = self.parser.parse_path(path)

        if len(components) == 1 and not components[0]["filter"]:
            # Simple path
            attr = components[0]["attribute"]
            if components[0]["sub_attribute"]:
                if attr not in data:
                    data[attr] = {}
                # Somewhat hacky workaround for the manager attribute of the enterprise schema
                # ideally we'd do this based on the schema
                if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
                    data[attr][components[0]["sub_attribute"]] = {"value": value}
                else:
                    data[attr][components[0]["sub_attribute"]] = value
            else:
                # Unlike ADD, REPLACE always overwrites the attribute outright
                data[attr] = value
        else:
            # Complex path with filters
            self._navigate_and_modify(data, components, value, "replace")
|
||||||
|
|
||||||
|
    def _navigate_and_modify(  # noqa PLR0912
        self, data: dict[str, Any], components: list[dict[str, Any]], value: Any, operation: str
    ):
        """Navigate through complex paths and apply modifications

        `operation` is one of "add", "remove" or "replace"; `value` is ignored
        for "remove". Mutates `data` in place.
        """
        current = data

        for i, component in enumerate(components):
            attr = component["attribute"]
            filter_expr = component["filter"]
            sub_attr = component["sub_attribute"]

            if filter_expr:
                # Handle array with filter
                if attr not in current:
                    if operation == "add":
                        current[attr] = []
                    else:
                        # Nothing to remove/replace under a missing attribute
                        return

                if not isinstance(current[attr], list):
                    # Filters only make sense on multi-valued (list) attributes
                    return

                # Find matching items
                matching_items = []
                for item in current[attr]:
                    if self._matches_filter(item, filter_expr):
                        matching_items.append(item)

                if not matching_items and operation == "add":
                    # Create new item if none match (only for simple comparison filters)
                    if filter_expr.get("type", "comparison") == "comparison":
                        new_item = {filter_expr["attribute"]: filter_expr["value"]}
                        current[attr].append(new_item)
                        matching_items = [new_item]

                # Apply operation to matching items
                for item in matching_items:
                    if sub_attr:
                        if operation in {"add", "replace"}:
                            item[sub_attr] = value
                        elif operation == "remove":
                            item.pop(sub_attr, None)
                    elif operation in {"add", "replace"}:
                        if isinstance(value, dict):
                            # Merge the patch value into the matched item
                            item.update(value)
                        else:
                            # If value is not a dict, we can't merge it
                            pass
                    elif operation == "remove":
                        # Remove the entire item
                        if item in current[attr]:
                            current[attr].remove(item)
            # Handle simple attribute
            elif i == len(components) - 1:
                # Last component
                if sub_attr:
                    if attr not in current:
                        current[attr] = {}
                    if operation in {"add", "replace"}:
                        current[attr][sub_attr] = value
                    elif operation == "remove":
                        current[attr].pop(sub_attr, None)
                elif operation in {"add", "replace"}:
                    current[attr] = value
                elif operation == "remove":
                    current.pop(attr, None)
            else:
                # Navigate deeper
                if attr not in current:
                    current[attr] = {}
                current = current[attr]
|
||||||
|
|
||||||
|
    def _matches_filter(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
        """Check if an item matches the filter expression"""
        # An absent/empty filter matches everything
        if not filter_expr:
            return True

        # Filter nodes default to "comparison" when no type is present
        filter_type = filter_expr.get("type", "comparison")

        if filter_type == "comparison":
            return self._matches_comparison(item, filter_expr)
        elif filter_type == "logical":
            return self._matches_logical(item, filter_expr)

        # Unknown filter node types never match
        return False
|
||||||
|
|
||||||
|
def _matches_comparison( # noqa PLR0912
|
||||||
|
self, item: dict[str, Any], filter_expr: dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
"""Check if an item matches a comparison filter"""
|
||||||
|
attr = filter_expr["attribute"]
|
||||||
|
operator = filter_expr["operator"]
|
||||||
|
expected_value = filter_expr["value"]
|
||||||
|
|
||||||
|
if attr not in item:
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_value = item[attr]
|
||||||
|
|
||||||
|
if operator == "eq":
|
||||||
|
return actual_value == expected_value
|
||||||
|
elif operator == "ne":
|
||||||
|
return actual_value != expected_value
|
||||||
|
elif operator == "co":
|
||||||
|
return str(expected_value) in str(actual_value)
|
||||||
|
elif operator == "sw":
|
||||||
|
return str(actual_value).startswith(str(expected_value))
|
||||||
|
elif operator == "ew":
|
||||||
|
return str(actual_value).endswith(str(expected_value))
|
||||||
|
elif operator == "gt":
|
||||||
|
return actual_value > expected_value
|
||||||
|
elif operator == "lt":
|
||||||
|
return actual_value < expected_value
|
||||||
|
elif operator == "ge":
|
||||||
|
return actual_value >= expected_value
|
||||||
|
elif operator == "le":
|
||||||
|
return actual_value <= expected_value
|
||||||
|
elif operator == "pr":
|
||||||
|
return actual_value is not None
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _matches_logical(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
|
||||||
|
"""Check if an item matches a logical filter expression"""
|
||||||
|
operator = filter_expr["operator"]
|
||||||
|
|
||||||
|
if operator == "and":
|
||||||
|
left_result = self._matches_filter(item, filter_expr["left"])
|
||||||
|
right_result = self._matches_filter(item, filter_expr["right"])
|
||||||
|
return left_result and right_result
|
||||||
|
elif operator == "or":
|
||||||
|
left_result = self._matches_filter(item, filter_expr["left"])
|
||||||
|
right_result = self._matches_filter(item, filter_expr["right"])
|
||||||
|
return left_result or right_result
|
||||||
|
elif operator == "not":
|
||||||
|
operand_result = self._matches_filter(item, filter_expr["operand"])
|
||||||
|
return not operand_result
|
||||||
|
|
||||||
|
return False
|
||||||
@@ -1101,17 +1101,6 @@
|
|||||||
"returned": "default",
|
"returned": "default",
|
||||||
"uniqueness": "none"
|
"uniqueness": "none"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "password",
|
|
||||||
"type": "string",
|
|
||||||
"multiValued": false,
|
|
||||||
"description": "The User's cleartext password. This attribute is intended to be used as a means to specify an initial\npassword when creating a new User or to reset an existing User's password.",
|
|
||||||
"required": false,
|
|
||||||
"caseExact": false,
|
|
||||||
"mutability": "writeOnly",
|
|
||||||
"returned": "never",
|
|
||||||
"uniqueness": "none"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "emails",
|
"name": "emails",
|
||||||
"type": "complex",
|
"type": "complex",
|
||||||
|
|||||||
@@ -75,7 +75,9 @@ class TestSCIMGroups(APITestCase):
|
|||||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 201)
|
self.assertEqual(response.status_code, 201)
|
||||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
self.assertTrue(
|
||||||
|
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||||
|
)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
Event.objects.filter(
|
Event.objects.filter(
|
||||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||||
@@ -86,6 +88,7 @@ class TestSCIMGroups(APITestCase):
|
|||||||
"""Test group create"""
|
"""Test group create"""
|
||||||
user = create_test_user()
|
user = create_test_user()
|
||||||
ext_id = generate_id()
|
ext_id = generate_id()
|
||||||
|
name = generate_id()
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-groups",
|
"authentik_sources_scim:v2-groups",
|
||||||
@@ -95,7 +98,7 @@ class TestSCIMGroups(APITestCase):
|
|||||||
),
|
),
|
||||||
data=dumps(
|
data=dumps(
|
||||||
{
|
{
|
||||||
"displayName": generate_id(),
|
"displayName": name,
|
||||||
"externalId": ext_id,
|
"externalId": ext_id,
|
||||||
"members": [{"value": str(user.uuid)}],
|
"members": [{"value": str(user.uuid)}],
|
||||||
}
|
}
|
||||||
@@ -104,12 +107,22 @@ class TestSCIMGroups(APITestCase):
|
|||||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 201)
|
self.assertEqual(response.status_code, 201)
|
||||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
connection = SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).first()
|
||||||
|
self.assertIsNotNone(connection)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
Event.objects.filter(
|
Event.objects.filter(
|
||||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||||
).exists()
|
).exists()
|
||||||
)
|
)
|
||||||
|
connection.refresh_from_db()
|
||||||
|
self.assertEqual(
|
||||||
|
connection.attributes,
|
||||||
|
{
|
||||||
|
"displayName": name,
|
||||||
|
"externalId": ext_id,
|
||||||
|
"members": [{"value": str(user.uuid)}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_group_create_members_empty(self):
|
def test_group_create_members_empty(self):
|
||||||
"""Test group create"""
|
"""Test group create"""
|
||||||
@@ -126,7 +139,9 @@ class TestSCIMGroups(APITestCase):
|
|||||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 201)
|
self.assertEqual(response.status_code, 201)
|
||||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
self.assertTrue(
|
||||||
|
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||||
|
)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
Event.objects.filter(
|
Event.objects.filter(
|
||||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||||
@@ -136,7 +151,9 @@ class TestSCIMGroups(APITestCase):
|
|||||||
def test_group_create_duplicate(self):
|
def test_group_create_duplicate(self):
|
||||||
"""Test group create (duplicate)"""
|
"""Test group create (duplicate)"""
|
||||||
group = Group.objects.create(name=generate_id())
|
group = Group.objects.create(name=generate_id())
|
||||||
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
existing = SCIMSourceGroup.objects.create(
|
||||||
|
source=self.source, group=group, external_id=uuid4()
|
||||||
|
)
|
||||||
ext_id = generate_id()
|
ext_id = generate_id()
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse(
|
reverse(
|
||||||
@@ -165,7 +182,9 @@ class TestSCIMGroups(APITestCase):
|
|||||||
def test_group_update(self):
|
def test_group_update(self):
|
||||||
"""Test group update"""
|
"""Test group update"""
|
||||||
group = Group.objects.create(name=generate_id())
|
group = Group.objects.create(name=generate_id())
|
||||||
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
existing = SCIMSourceGroup.objects.create(
|
||||||
|
source=self.source, group=group, external_id=uuid4()
|
||||||
|
)
|
||||||
ext_id = generate_id()
|
ext_id = generate_id()
|
||||||
response = self.client.put(
|
response = self.client.put(
|
||||||
reverse(
|
reverse(
|
||||||
@@ -205,12 +224,49 @@ class TestSCIMGroups(APITestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_group_patch_add(self):
|
def test_group_patch_modify(self):
|
||||||
|
"""Test group patch"""
|
||||||
|
group = Group.objects.create(name=generate_id())
|
||||||
|
connection = SCIMSourceGroup.objects.create(
|
||||||
|
source=self.source,
|
||||||
|
group=group,
|
||||||
|
external_id=uuid4(),
|
||||||
|
attributes={"displayName": group.name, "members": []},
|
||||||
|
)
|
||||||
|
response = self.client.patch(
|
||||||
|
reverse(
|
||||||
|
"authentik_sources_scim:v2-groups",
|
||||||
|
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
|
||||||
|
),
|
||||||
|
data=dumps(
|
||||||
|
{
|
||||||
|
"Operations": [
|
||||||
|
{
|
||||||
|
"op": "Add",
|
||||||
|
"value": {"externalId": "d85051cb-0557-4aa1-98ca-51eabcee4d40"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
content_type=SCIM_CONTENT_TYPE,
|
||||||
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200, response.content)
|
||||||
|
connection = SCIMSourceGroup.objects.filter(id="d85051cb-0557-4aa1-98ca-51eabcee4d40")
|
||||||
|
self.assertIsNotNone(connection)
|
||||||
|
|
||||||
|
def test_group_patch_member_add(self):
|
||||||
"""Test group patch"""
|
"""Test group patch"""
|
||||||
user = create_test_user()
|
user = create_test_user()
|
||||||
|
other_user = create_test_user()
|
||||||
group = Group.objects.create(name=generate_id())
|
group = Group.objects.create(name=generate_id())
|
||||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
group.users.add(other_user)
|
||||||
|
connection = SCIMSourceGroup.objects.create(
|
||||||
|
source=self.source,
|
||||||
|
group=group,
|
||||||
|
external_id=uuid4(),
|
||||||
|
attributes={"displayName": group.name, "members": [{"value": str(other_user.uuid)}]},
|
||||||
|
)
|
||||||
response = self.client.patch(
|
response = self.client.patch(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-groups",
|
"authentik_sources_scim:v2-groups",
|
||||||
@@ -222,7 +278,7 @@ class TestSCIMGroups(APITestCase):
|
|||||||
{
|
{
|
||||||
"op": "Add",
|
"op": "Add",
|
||||||
"path": "members",
|
"path": "members",
|
||||||
"value": {"value": str(user.uuid)},
|
"value": [{"value": str(user.uuid)}],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -230,16 +286,33 @@ class TestSCIMGroups(APITestCase):
|
|||||||
content_type=SCIM_CONTENT_TYPE,
|
content_type=SCIM_CONTENT_TYPE,
|
||||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, second=200)
|
self.assertEqual(response.status_code, 200, response.content)
|
||||||
self.assertTrue(group.users.filter(pk=user.pk).exists())
|
self.assertTrue(group.users.filter(pk=user.pk).exists())
|
||||||
|
self.assertTrue(group.users.filter(pk=other_user.pk).exists())
|
||||||
|
connection.refresh_from_db()
|
||||||
|
self.assertEqual(
|
||||||
|
connection.attributes,
|
||||||
|
{
|
||||||
|
"displayName": group.name,
|
||||||
|
"members": sorted(
|
||||||
|
[{"value": str(other_user.uuid)}, {"value": str(user.uuid)}],
|
||||||
|
key=lambda u: u["value"],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_group_patch_remove(self):
|
def test_group_patch_member_remove(self):
|
||||||
"""Test group patch"""
|
"""Test group patch"""
|
||||||
user = create_test_user()
|
user = create_test_user()
|
||||||
|
|
||||||
group = Group.objects.create(name=generate_id())
|
group = Group.objects.create(name=generate_id())
|
||||||
group.users.add(user)
|
group.users.add(user)
|
||||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
connection = SCIMSourceGroup.objects.create(
|
||||||
|
source=self.source,
|
||||||
|
group=group,
|
||||||
|
external_id=uuid4(),
|
||||||
|
attributes={"displayName": group.name, "members": []},
|
||||||
|
)
|
||||||
response = self.client.patch(
|
response = self.client.patch(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-groups",
|
"authentik_sources_scim:v2-groups",
|
||||||
@@ -251,7 +324,7 @@ class TestSCIMGroups(APITestCase):
|
|||||||
{
|
{
|
||||||
"op": "remove",
|
"op": "remove",
|
||||||
"path": "members",
|
"path": "members",
|
||||||
"value": {"value": str(user.uuid)},
|
"value": [{"value": str(user.uuid)}],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -259,13 +332,21 @@ class TestSCIMGroups(APITestCase):
|
|||||||
content_type=SCIM_CONTENT_TYPE,
|
content_type=SCIM_CONTENT_TYPE,
|
||||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, second=200)
|
self.assertEqual(response.status_code, 200, response.content)
|
||||||
self.assertFalse(group.users.filter(pk=user.pk).exists())
|
self.assertFalse(group.users.filter(pk=user.pk).exists())
|
||||||
|
connection.refresh_from_db()
|
||||||
|
self.assertEqual(
|
||||||
|
connection.attributes,
|
||||||
|
{
|
||||||
|
"displayName": group.name,
|
||||||
|
"members": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_group_delete(self):
|
def test_group_delete(self):
|
||||||
"""Test group delete"""
|
"""Test group delete"""
|
||||||
group = Group.objects.create(name=generate_id())
|
group = Group.objects.create(name=generate_id())
|
||||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
SCIMSourceGroup.objects.create(source=self.source, group=group, external_id=uuid4())
|
||||||
response = self.client.delete(
|
response = self.client.delete(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-groups",
|
"authentik_sources_scim:v2-groups",
|
||||||
|
|||||||
510
authentik/sources/scim/tests/test_lexer.py
Normal file
510
authentik/sources/scim/tests/test_lexer.py
Normal file
@@ -0,0 +1,510 @@
|
|||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from authentik.sources.scim.constants import (
|
||||||
|
SCIM_URN_GROUP,
|
||||||
|
SCIM_URN_SCHEMA,
|
||||||
|
SCIM_URN_USER,
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
)
|
||||||
|
from authentik.sources.scim.patch.lexer import SCIMPathLexer, Token, TokenType
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenType(TestCase):
|
||||||
|
"""Test TokenType enum"""
|
||||||
|
|
||||||
|
def test_token_type_values(self):
|
||||||
|
"""Test that all token types have correct values"""
|
||||||
|
self.assertEqual(TokenType.ATTRIBUTE.value, "ATTRIBUTE")
|
||||||
|
self.assertEqual(TokenType.DOT.value, "DOT")
|
||||||
|
self.assertEqual(TokenType.LBRACKET.value, "LBRACKET")
|
||||||
|
self.assertEqual(TokenType.RBRACKET.value, "RBRACKET")
|
||||||
|
self.assertEqual(TokenType.LPAREN.value, "LPAREN")
|
||||||
|
self.assertEqual(TokenType.RPAREN.value, "RPAREN")
|
||||||
|
self.assertEqual(TokenType.STRING.value, "STRING")
|
||||||
|
self.assertEqual(TokenType.NUMBER.value, "NUMBER")
|
||||||
|
self.assertEqual(TokenType.BOOLEAN.value, "BOOLEAN")
|
||||||
|
self.assertEqual(TokenType.NULL.value, "NULL")
|
||||||
|
self.assertEqual(TokenType.OPERATOR.value, "OPERATOR")
|
||||||
|
self.assertEqual(TokenType.AND.value, "AND")
|
||||||
|
self.assertEqual(TokenType.OR.value, "OR")
|
||||||
|
self.assertEqual(TokenType.NOT.value, "NOT")
|
||||||
|
self.assertEqual(TokenType.EOF.value, "EOF")
|
||||||
|
|
||||||
|
|
||||||
|
class TestToken(TestCase):
|
||||||
|
"""Test Token dataclass"""
|
||||||
|
|
||||||
|
def test_token_creation(self):
|
||||||
|
"""Test token creation with all parameters"""
|
||||||
|
token = Token(TokenType.ATTRIBUTE, "userName", 5)
|
||||||
|
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token.value, "userName")
|
||||||
|
self.assertEqual(token.position, 5)
|
||||||
|
|
||||||
|
def test_token_creation_default_position(self):
|
||||||
|
"""Test token creation with default position"""
|
||||||
|
token = Token(TokenType.DOT, ".")
|
||||||
|
self.assertEqual(token.type, TokenType.DOT)
|
||||||
|
self.assertEqual(token.value, ".")
|
||||||
|
self.assertEqual(token.position, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSCIMPathLexer(TestCase):
|
||||||
|
"""Test SCIMPathLexer class"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures"""
|
||||||
|
self.simple_lexer = SCIMPathLexer("userName")
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test lexer initialization"""
|
||||||
|
lexer = SCIMPathLexer("test")
|
||||||
|
self.assertEqual(lexer.text, "test")
|
||||||
|
self.assertEqual(lexer.pos, 0)
|
||||||
|
self.assertEqual(lexer.current_char, "t")
|
||||||
|
self.assertIn(SCIM_URN_SCHEMA, lexer.schema_urns)
|
||||||
|
self.assertIn(SCIM_URN_GROUP, lexer.schema_urns)
|
||||||
|
self.assertIn(SCIM_URN_USER, lexer.schema_urns)
|
||||||
|
self.assertIn(SCIM_URN_USER_ENTERPRISE, lexer.schema_urns)
|
||||||
|
self.assertEqual(
|
||||||
|
lexer.OPERATORS, ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_init_empty_string(self):
|
||||||
|
"""Test lexer initialization with empty string"""
|
||||||
|
lexer = SCIMPathLexer("")
|
||||||
|
self.assertEqual(lexer.text, "")
|
||||||
|
self.assertEqual(lexer.pos, 0)
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_advance(self):
|
||||||
|
"""Test advance method"""
|
||||||
|
lexer = SCIMPathLexer("abc")
|
||||||
|
self.assertEqual(lexer.current_char, "a")
|
||||||
|
|
||||||
|
lexer.advance()
|
||||||
|
self.assertEqual(lexer.pos, 1)
|
||||||
|
self.assertEqual(lexer.current_char, "b")
|
||||||
|
|
||||||
|
lexer.advance()
|
||||||
|
self.assertEqual(lexer.pos, 2)
|
||||||
|
self.assertEqual(lexer.current_char, "c")
|
||||||
|
|
||||||
|
lexer.advance()
|
||||||
|
self.assertEqual(lexer.pos, 3)
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_skip_whitespace(self):
|
||||||
|
"""Test skip_whitespace method"""
|
||||||
|
lexer = SCIMPathLexer(" \t\n abc")
|
||||||
|
lexer.skip_whitespace()
|
||||||
|
self.assertEqual(lexer.current_char, "a")
|
||||||
|
|
||||||
|
def test_skip_whitespace_only_whitespace(self):
|
||||||
|
"""Test skip_whitespace with only whitespace"""
|
||||||
|
lexer = SCIMPathLexer(" \t\n ")
|
||||||
|
lexer.skip_whitespace()
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_skip_whitespace_no_whitespace(self):
|
||||||
|
"""Test skip_whitespace with no leading whitespace"""
|
||||||
|
lexer = SCIMPathLexer("abc")
|
||||||
|
original_pos = lexer.pos
|
||||||
|
lexer.skip_whitespace()
|
||||||
|
self.assertEqual(lexer.pos, original_pos)
|
||||||
|
self.assertEqual(lexer.current_char, "a")
|
||||||
|
|
||||||
|
def test_read_string_double_quotes(self):
|
||||||
|
"""Test reading double-quoted string"""
|
||||||
|
lexer = SCIMPathLexer('"hello world"')
|
||||||
|
result = lexer.read_string('"')
|
||||||
|
self.assertEqual(result, "hello world")
|
||||||
|
self.assertIsNone(lexer.current_char) # Should be at end
|
||||||
|
|
||||||
|
def test_read_string_single_quotes(self):
|
||||||
|
"""Test reading single-quoted string"""
|
||||||
|
lexer = SCIMPathLexer("'hello world'")
|
||||||
|
result = lexer.read_string("'")
|
||||||
|
self.assertEqual(result, "hello world")
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_read_string_with_escapes(self):
|
||||||
|
"""Test reading string with escape characters"""
|
||||||
|
lexer = SCIMPathLexer('"hello \\"world\\""')
|
||||||
|
result = lexer.read_string('"')
|
||||||
|
self.assertEqual(result, 'hello "world"')
|
||||||
|
|
||||||
|
def test_read_string_with_backslash_at_end(self):
|
||||||
|
"""Test reading string with backslash at end"""
|
||||||
|
lexer = SCIMPathLexer('"hello\\"')
|
||||||
|
result = lexer.read_string('"')
|
||||||
|
self.assertEqual(result, 'hello"')
|
||||||
|
|
||||||
|
def test_read_string_unclosed(self):
|
||||||
|
"""Test reading unclosed string"""
|
||||||
|
lexer = SCIMPathLexer('"hello world')
|
||||||
|
result = lexer.read_string('"')
|
||||||
|
self.assertEqual(result, "hello world")
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_read_string_empty(self):
|
||||||
|
"""Test reading empty string"""
|
||||||
|
lexer = SCIMPathLexer('""')
|
||||||
|
result = lexer.read_string('"')
|
||||||
|
self.assertEqual(result, "")
|
||||||
|
|
||||||
|
def test_read_number_integer(self):
|
||||||
|
"""Test reading integer number"""
|
||||||
|
lexer = SCIMPathLexer("123")
|
||||||
|
result = lexer.read_number()
|
||||||
|
self.assertEqual(result, "123")
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_read_number_float(self):
|
||||||
|
"""Test reading float number"""
|
||||||
|
lexer = SCIMPathLexer("123.456")
|
||||||
|
result = lexer.read_number()
|
||||||
|
self.assertEqual(result, "123.456")
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_read_number_with_multiple_dots(self):
|
||||||
|
"""Test reading number with multiple dots (invalid but handled)"""
|
||||||
|
lexer = SCIMPathLexer("123.456.789")
|
||||||
|
result = lexer.read_number()
|
||||||
|
self.assertEqual(result, "123.456.789")
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_read_number_starting_with_dot(self):
|
||||||
|
"""Test reading number starting with dot"""
|
||||||
|
lexer = SCIMPathLexer(".123")
|
||||||
|
result = lexer.read_number()
|
||||||
|
self.assertEqual(result, ".123")
|
||||||
|
|
||||||
|
def test_read_identifier_simple(self):
|
||||||
|
"""Test reading simple identifier"""
|
||||||
|
lexer = SCIMPathLexer("userName")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "userName")
|
||||||
|
self.assertIsNone(lexer.current_char)
|
||||||
|
|
||||||
|
def test_read_identifier_with_underscore(self):
|
||||||
|
"""Test reading identifier with underscore"""
|
||||||
|
lexer = SCIMPathLexer("user_name")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "user_name")
|
||||||
|
|
||||||
|
def test_read_identifier_with_hyphen(self):
|
||||||
|
"""Test reading identifier with hyphen"""
|
||||||
|
lexer = SCIMPathLexer("user-name")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "user-name")
|
||||||
|
|
||||||
|
def test_read_identifier_with_colon(self):
|
||||||
|
"""Test reading identifier with colon (URN format)"""
|
||||||
|
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
|
||||||
|
|
||||||
|
def test_read_identifier_schema_urn(self):
|
||||||
|
"""Test reading schema URN identifier"""
|
||||||
|
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, SCIM_URN_USER)
|
||||||
|
self.assertEqual(lexer.current_char, ".") # Should stop at dot and set current_char to dot
|
||||||
|
|
||||||
|
def test_read_identifier_with_version_number(self):
|
||||||
|
"""Test reading identifier with version number (dots followed by digits)"""
|
||||||
|
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
|
||||||
|
|
||||||
|
def test_read_identifier_partial_urn_match(self):
|
||||||
|
"""Test reading identifier that partially matches URN"""
|
||||||
|
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:CustomUser")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:CustomUser")
|
||||||
|
|
||||||
|
# Test get_next_token method
|
||||||
|
def test_get_next_token_dot(self):
|
||||||
|
"""Test tokenizing dot"""
|
||||||
|
lexer = SCIMPathLexer(".")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.DOT)
|
||||||
|
self.assertEqual(token.value, ".")
|
||||||
|
|
||||||
|
def test_get_next_token_lbracket(self):
|
||||||
|
"""Test tokenizing left bracket"""
|
||||||
|
lexer = SCIMPathLexer("[")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.LBRACKET)
|
||||||
|
self.assertEqual(token.value, "[")
|
||||||
|
|
||||||
|
def test_get_next_token_rbracket(self):
|
||||||
|
"""Test tokenizing right bracket"""
|
||||||
|
lexer = SCIMPathLexer("]")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.RBRACKET)
|
||||||
|
self.assertEqual(token.value, "]")
|
||||||
|
|
||||||
|
def test_get_next_token_lparen(self):
|
||||||
|
"""Test tokenizing left parenthesis"""
|
||||||
|
lexer = SCIMPathLexer("(")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.LPAREN)
|
||||||
|
self.assertEqual(token.value, "(")
|
||||||
|
|
||||||
|
def test_get_next_token_rparen(self):
|
||||||
|
"""Test tokenizing right parenthesis"""
|
||||||
|
lexer = SCIMPathLexer(")")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.RPAREN)
|
||||||
|
self.assertEqual(token.value, ")")
|
||||||
|
|
||||||
|
def test_get_next_token_string_double_quotes(self):
|
||||||
|
"""Test tokenizing double-quoted string"""
|
||||||
|
lexer = SCIMPathLexer('"test string"')
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.STRING)
|
||||||
|
self.assertEqual(token.value, "test string")
|
||||||
|
|
||||||
|
def test_get_next_token_string_single_quotes(self):
|
||||||
|
"""Test tokenizing single-quoted string"""
|
||||||
|
lexer = SCIMPathLexer("'test string'")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.STRING)
|
||||||
|
self.assertEqual(token.value, "test string")
|
||||||
|
|
||||||
|
def test_get_next_token_number_integer(self):
|
||||||
|
"""Test tokenizing integer"""
|
||||||
|
lexer = SCIMPathLexer("123")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.NUMBER)
|
||||||
|
self.assertEqual(token.value, "123")
|
||||||
|
|
||||||
|
def test_get_next_token_number_float(self):
|
||||||
|
"""Test tokenizing float"""
|
||||||
|
lexer = SCIMPathLexer("123.45")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.NUMBER)
|
||||||
|
self.assertEqual(token.value, "123.45")
|
||||||
|
|
||||||
|
def test_get_next_token_boolean_true(self):
|
||||||
|
"""Test tokenizing boolean true"""
|
||||||
|
lexer = SCIMPathLexer("true")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||||
|
self.assertTrue(token.value)
|
||||||
|
|
||||||
|
def test_get_next_token_boolean_false(self):
|
||||||
|
"""Test tokenizing boolean false"""
|
||||||
|
lexer = SCIMPathLexer("false")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||||
|
self.assertFalse(token.value)
|
||||||
|
|
||||||
|
def test_get_next_token_boolean_case_insensitive(self):
|
||||||
|
"""Test tokenizing boolean with different cases"""
|
||||||
|
for value in ["TRUE", "True", "FALSE", "False"]:
|
||||||
|
with self.subTest(value=value):
|
||||||
|
lexer = SCIMPathLexer(value)
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||||
|
|
||||||
|
def test_get_next_token_null(self):
|
||||||
|
"""Test tokenizing null"""
|
||||||
|
lexer = SCIMPathLexer("null")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.NULL)
|
||||||
|
self.assertIsNone(token.value)
|
||||||
|
|
||||||
|
def test_get_next_token_null_case_insensitive(self):
|
||||||
|
"""Test tokenizing null with different cases"""
|
||||||
|
for value in ["NULL", "Null"]:
|
||||||
|
with self.subTest(value=value):
|
||||||
|
lexer = SCIMPathLexer(value)
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.NULL)
|
||||||
|
|
||||||
|
def test_get_next_token_and(self):
|
||||||
|
"""Test tokenizing AND operator"""
|
||||||
|
lexer = SCIMPathLexer("and")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.AND)
|
||||||
|
self.assertEqual(token.value, "and")
|
||||||
|
|
||||||
|
def test_get_next_token_or(self):
|
||||||
|
"""Test tokenizing OR operator"""
|
||||||
|
lexer = SCIMPathLexer("or")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.OR)
|
||||||
|
self.assertEqual(token.value, "or")
|
||||||
|
|
||||||
|
def test_get_next_token_not(self):
|
||||||
|
"""Test tokenizing NOT operator"""
|
||||||
|
lexer = SCIMPathLexer("not")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.NOT)
|
||||||
|
self.assertEqual(token.value, "not")
|
||||||
|
|
||||||
|
def test_get_next_token_operators(self):
|
||||||
|
"""Test tokenizing all comparison operators"""
|
||||||
|
operators = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||||
|
for op in operators:
|
||||||
|
with self.subTest(operator=op):
|
||||||
|
lexer = SCIMPathLexer(op)
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.OPERATOR)
|
||||||
|
self.assertEqual(token.value, op)
|
||||||
|
|
||||||
|
def test_get_next_token_operators_case_insensitive(self):
|
||||||
|
"""Test tokenizing operators with different cases"""
|
||||||
|
for op in ["EQ", "Eq", "NE", "Ne"]:
|
||||||
|
with self.subTest(operator=op):
|
||||||
|
lexer = SCIMPathLexer(op)
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.OPERATOR)
|
||||||
|
self.assertEqual(token.value, op.lower())
|
||||||
|
|
||||||
|
def test_get_next_token_attribute(self):
|
||||||
|
"""Test tokenizing attribute name"""
|
||||||
|
lexer = SCIMPathLexer("userName")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token.value, "userName")
|
||||||
|
|
||||||
|
def test_get_next_token_attribute_with_underscore(self):
|
||||||
|
"""Test tokenizing attribute name with underscore"""
|
||||||
|
lexer = SCIMPathLexer("_userName")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token.value, "_userName")
|
||||||
|
|
||||||
|
def test_get_next_token_eof(self):
|
||||||
|
"""Test tokenizing end of file"""
|
||||||
|
lexer = SCIMPathLexer("")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.EOF)
|
||||||
|
self.assertEqual(token.value, "")
|
||||||
|
|
||||||
|
def test_get_next_token_with_whitespace(self):
|
||||||
|
"""Test tokenizing with leading whitespace"""
|
||||||
|
lexer = SCIMPathLexer(" userName")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token.value, "userName")
|
||||||
|
|
||||||
|
def test_get_next_token_skip_unknown_characters(self):
|
||||||
|
"""Test that unknown characters are skipped"""
|
||||||
|
lexer = SCIMPathLexer("@#$userName")
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token.value, "userName")
|
||||||
|
|
||||||
|
def test_get_next_token_multiple_tokens(self):
|
||||||
|
"""Test tokenizing multiple tokens in sequence"""
|
||||||
|
lexer = SCIMPathLexer("userName.givenName")
|
||||||
|
|
||||||
|
token1 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token1.value, "userName")
|
||||||
|
|
||||||
|
token2 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token2.type, TokenType.DOT)
|
||||||
|
self.assertEqual(token2.value, ".")
|
||||||
|
|
||||||
|
token3 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token3.value, "givenName")
|
||||||
|
|
||||||
|
token4 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token4.type, TokenType.EOF)
|
||||||
|
|
||||||
|
def test_get_next_token_complex_filter(self):
|
||||||
|
"""Test tokenizing complex filter expression"""
|
||||||
|
lexer = SCIMPathLexer('emails[type eq "work" and primary eq true]')
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
while True:
|
||||||
|
token = lexer.get_next_token()
|
||||||
|
tokens.append(token)
|
||||||
|
if token.type == TokenType.EOF:
|
||||||
|
break
|
||||||
|
|
||||||
|
expected_types = [
|
||||||
|
TokenType.ATTRIBUTE, # emails
|
||||||
|
TokenType.LBRACKET, # [
|
||||||
|
TokenType.ATTRIBUTE, # type
|
||||||
|
TokenType.OPERATOR, # eq
|
||||||
|
TokenType.STRING, # "work"
|
||||||
|
TokenType.AND, # and
|
||||||
|
TokenType.ATTRIBUTE, # primary
|
||||||
|
TokenType.OPERATOR, # eq
|
||||||
|
TokenType.BOOLEAN, # true
|
||||||
|
TokenType.RBRACKET, # ]
|
||||||
|
TokenType.EOF,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(len(tokens), len(expected_types))
|
||||||
|
for token, expected_type in zip(tokens, expected_types, strict=False):
|
||||||
|
self.assertEqual(token.type, expected_type)
|
||||||
|
|
||||||
|
def test_get_next_token_urn_attribute(self):
|
||||||
|
"""Test tokenizing URN-based attribute"""
|
||||||
|
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
|
||||||
|
|
||||||
|
token1 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token1.value, SCIM_URN_USER)
|
||||||
|
|
||||||
|
token2 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token2.type, TokenType.DOT)
|
||||||
|
|
||||||
|
token3 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token3.value, "userName")
|
||||||
|
|
||||||
|
def test_get_next_token_enterprise_urn(self):
|
||||||
|
"""Test tokenizing enterprise URN"""
|
||||||
|
lexer = SCIMPathLexer(f"{SCIM_URN_USER_ENTERPRISE}.manager")
|
||||||
|
|
||||||
|
token1 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||||
|
self.assertEqual(token1.value, SCIM_URN_USER_ENTERPRISE)
|
||||||
|
|
||||||
|
token2 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token2.type, TokenType.DOT)
|
||||||
|
|
||||||
|
def test_lexer_state_after_eof(self):
|
||||||
|
"""Test lexer state after reaching EOF"""
|
||||||
|
lexer = SCIMPathLexer("a")
|
||||||
|
|
||||||
|
# Get first token
|
||||||
|
token1 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||||
|
|
||||||
|
# Get EOF token
|
||||||
|
token2 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token2.type, TokenType.EOF)
|
||||||
|
|
||||||
|
# Should continue returning EOF
|
||||||
|
token3 = lexer.get_next_token()
|
||||||
|
self.assertEqual(token3.type, TokenType.EOF)
|
||||||
|
|
||||||
|
def test_read_identifier_edge_cases(self):
|
||||||
|
"""Test read_identifier with edge cases"""
|
||||||
|
# Test identifier ending with colon
|
||||||
|
lexer = SCIMPathLexer("test:")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "test:")
|
||||||
|
|
||||||
|
# Test identifier with numbers
|
||||||
|
lexer = SCIMPathLexer("test123")
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, "test123")
|
||||||
|
|
||||||
|
def test_complex_urn_parsing(self):
|
||||||
|
"""Test parsing complex URN with version numbers"""
|
||||||
|
urn = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||||
|
lexer = SCIMPathLexer(urn)
|
||||||
|
result = lexer.read_identifier()
|
||||||
|
self.assertEqual(result, urn)
|
||||||
1254
authentik/sources/scim/tests/test_patch.py
Normal file
1254
authentik/sources/scim/tests/test_patch.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -10,6 +10,7 @@ from authentik.core.tests.utils import create_test_user
|
|||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
||||||
|
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||||
from authentik.sources.scim.models import SCIMSource, SCIMSourcePropertyMapping, SCIMSourceUser
|
from authentik.sources.scim.models import SCIMSource, SCIMSourcePropertyMapping, SCIMSourceUser
|
||||||
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
|
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
|
||||||
|
|
||||||
@@ -81,7 +82,9 @@ class TestSCIMUsers(APITestCase):
|
|||||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 201)
|
self.assertEqual(response.status_code, 201)
|
||||||
self.assertTrue(SCIMSourceUser.objects.filter(source=self.source, id=ext_id).exists())
|
self.assertTrue(
|
||||||
|
SCIMSourceUser.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||||
|
)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
Event.objects.filter(
|
Event.objects.filter(
|
||||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||||
@@ -174,14 +177,16 @@ class TestSCIMUsers(APITestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 201)
|
self.assertEqual(response.status_code, 201)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
|
SCIMSourceUser.objects.get(source=self.source, external_id=ext_id).user.attributes[
|
||||||
|
"phone"
|
||||||
|
],
|
||||||
"0123456789",
|
"0123456789",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_user_update(self):
|
def test_user_update(self):
|
||||||
"""Test user update"""
|
"""Test user update"""
|
||||||
user = create_test_user()
|
user = create_test_user()
|
||||||
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
|
existing = SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
|
||||||
ext_id = generate_id()
|
ext_id = generate_id()
|
||||||
response = self.client.put(
|
response = self.client.put(
|
||||||
reverse(
|
reverse(
|
||||||
@@ -209,10 +214,51 @@ class TestSCIMUsers(APITestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
def test_user_update_patch(self):
|
||||||
|
"""Test user update (patch)"""
|
||||||
|
user = create_test_user()
|
||||||
|
existing = SCIMSourceUser.objects.create(
|
||||||
|
source=self.source,
|
||||||
|
user=user,
|
||||||
|
external_id=uuid4(),
|
||||||
|
attributes={
|
||||||
|
"userName": generate_id(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response = self.client.patch(
|
||||||
|
reverse(
|
||||||
|
"authentik_sources_scim:v2-users",
|
||||||
|
kwargs={
|
||||||
|
"source_slug": self.source.slug,
|
||||||
|
"user_id": str(user.uuid),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
data=dumps(
|
||||||
|
{
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{
|
||||||
|
"op": "Add",
|
||||||
|
"path": f"{SCIM_URN_USER_ENTERPRISE}:manager",
|
||||||
|
"value": "86b2ed3e-30cd-4881-bb58-c4e910821339",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
content_type=SCIM_CONTENT_TYPE,
|
||||||
|
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
existing.refresh_from_db()
|
||||||
|
self.assertEqual(
|
||||||
|
existing.attributes[SCIM_URN_USER_ENTERPRISE],
|
||||||
|
{"manager": {"value": "86b2ed3e-30cd-4881-bb58-c4e910821339"}},
|
||||||
|
)
|
||||||
|
|
||||||
def test_user_delete(self):
|
def test_user_delete(self):
|
||||||
"""Test user delete"""
|
"""Test user delete"""
|
||||||
user = create_test_user()
|
user = create_test_user()
|
||||||
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
|
SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
|
||||||
response = self.client.delete(
|
response = self.client.delete(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-users",
|
"authentik_sources_scim:v2-users",
|
||||||
|
|||||||
488
authentik/sources/scim/tests/test_users_patch.py
Normal file
488
authentik/sources/scim/tests/test_users_patch.py
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.core.tests.utils import create_test_user
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||||
|
from authentik.sources.scim.models import SCIMSource, SCIMSourceUser
|
||||||
|
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class TestSCIMUsersPatch(APITestCase):
|
||||||
|
"""Test SCIM User Patch"""
|
||||||
|
|
||||||
|
def test_add(self):
|
||||||
|
req = {
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{"op": "Add", "path": "name.givenName", "value": "aqwer"},
|
||||||
|
{"op": "Add", "path": "name.familyName", "value": "qwerqqqq"},
|
||||||
|
{"op": "Add", "path": "name.formatted", "value": "aqwer qwerqqqq"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"name": {
|
||||||
|
"givenName": "aqwer",
|
||||||
|
"familyName": "qwerqqqq",
|
||||||
|
"formatted": "aqwer qwerqqqq",
|
||||||
|
},
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_add_no_path(self):
|
||||||
|
"""Test add patch with no path set"""
|
||||||
|
req = {
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{"op": "Add", "value": {"externalId": "aqwer"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "aqwer",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_replace(self):
|
||||||
|
req = {
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{"op": "Replace", "path": "name", "value": {"givenName": "aqwer"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"name": {
|
||||||
|
"givenName": "aqwer",
|
||||||
|
},
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_replace_no_path(self):
|
||||||
|
"""Test value replace with no path"""
|
||||||
|
req = {
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{"op": "Replace", "value": {"externalId": "aqwer"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "aqwer",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_remove(self):
|
||||||
|
req = {
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{"op": "Remove", "path": "name", "value": {"givenName": "aqwer"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"name": {
|
||||||
|
"givenName": "aqwer",
|
||||||
|
},
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_large(self):
|
||||||
|
"""Large amount of patch operations"""
|
||||||
|
req = {
|
||||||
|
"Operations": [
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "emails[primary eq true].value",
|
||||||
|
"value": "dandre_kling@wintheiser.info",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "phoneNumbers[primary eq true].value",
|
||||||
|
"value": "72-634-1548",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "phoneNumbers[primary eq true].display",
|
||||||
|
"value": "72-634-1548",
|
||||||
|
},
|
||||||
|
{"op": "replace", "path": "ims[primary eq true].value", "value": "GXSGJKWGHVVS"},
|
||||||
|
{"op": "replace", "path": "ims[primary eq true].display", "value": "IMCHDKUQIPYB"},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "photos[primary eq true].display",
|
||||||
|
"value": "TWAWLHHSUNIV",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "addresses[primary eq true].formatted",
|
||||||
|
"value": "TMINZQAJQDCL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "addresses[primary eq true].streetAddress",
|
||||||
|
"value": "081 Wisoky Key",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "addresses[primary eq true].locality",
|
||||||
|
"value": "DPFASBZRPMDP",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "addresses[primary eq true].region",
|
||||||
|
"value": "WHSTJSPIPTCF",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "addresses[primary eq true].postalCode",
|
||||||
|
"value": "ko28 1qa",
|
||||||
|
},
|
||||||
|
{"op": "replace", "path": "addresses[primary eq true].country", "value": "Taiwan"},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "entitlements[primary eq true].value",
|
||||||
|
"value": "NGBJMUYZVVBX",
|
||||||
|
},
|
||||||
|
{"op": "replace", "path": "roles[primary eq true].value", "value": "XEELVFMMWCVM"},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "x509Certificates[primary eq true].value",
|
||||||
|
"value": "UYISMEDOXUZY",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"value": {
|
||||||
|
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
|
||||||
|
"name.formatted": "Dell",
|
||||||
|
"name.familyName": "Gay",
|
||||||
|
"name.givenName": "Kyler",
|
||||||
|
"name.middleName": "Hannah",
|
||||||
|
"name.honorificPrefix": "Cassie",
|
||||||
|
"name.honorificSuffix": "Yolanda",
|
||||||
|
"displayName": "DPRLIJSFQMTL",
|
||||||
|
"nickName": "BKSPMIRMFBTI",
|
||||||
|
"title": "NBZCOAXVYJUY",
|
||||||
|
"userType": "ZGJMYZRUORZE",
|
||||||
|
"preferredLanguage": "as-IN",
|
||||||
|
"locale": "JLOJHLPWZODG",
|
||||||
|
"timezone": "America/Argentina/Rio_Gallegos",
|
||||||
|
"active": True,
|
||||||
|
f"{SCIM_URN_USER_ENTERPRISE}:employeeNumber": "PDFWRRZBQOHB",
|
||||||
|
f"{SCIM_URN_USER_ENTERPRISE}:costCenter": "HACMZWSEDOTQ",
|
||||||
|
f"{SCIM_URN_USER_ENTERPRISE}:organization": "LXVHJUOLNCLS",
|
||||||
|
f"{SCIM_URN_USER_ENTERPRISE}:division": "JASVTPKPBPMG",
|
||||||
|
f"{SCIM_URN_USER_ENTERPRISE}:department": "GMSBFLMNPABY",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"active": True,
|
||||||
|
"addresses": [
|
||||||
|
{
|
||||||
|
"primary": "true",
|
||||||
|
"formatted": "BLJMCNXHYLZK",
|
||||||
|
"streetAddress": "7801 Jacobs Fork",
|
||||||
|
"locality": "HZJBJWFAKXDD",
|
||||||
|
"region": "GJXCXPMIIKWK",
|
||||||
|
"postalCode": "pv82 8ua",
|
||||||
|
"country": "India",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"displayName": "KEFXCHKHAFOT",
|
||||||
|
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
|
||||||
|
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
|
||||||
|
"externalId": "448d2786-7bf6-4e03-a4ef-64cbaf162fa7",
|
||||||
|
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
|
||||||
|
"locale": "PJNYJHWJILTI",
|
||||||
|
"name": {
|
||||||
|
"formatted": "Ladarius",
|
||||||
|
"familyName": "Manley",
|
||||||
|
"givenName": "Mazie",
|
||||||
|
"middleName": "Vernon",
|
||||||
|
"honorificPrefix": "Melyssa",
|
||||||
|
"honorificSuffix": "Demarcus",
|
||||||
|
},
|
||||||
|
"nickName": "HTPKOXMWZKHL",
|
||||||
|
"phoneNumbers": [
|
||||||
|
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
|
||||||
|
],
|
||||||
|
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
|
||||||
|
"preferredLanguage": "wae",
|
||||||
|
"profileUrl": "HPSEOIPXMGOH",
|
||||||
|
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"timezone": "America/Indiana/Petersburg",
|
||||||
|
"title": "EJWFXLHNHMCD",
|
||||||
|
SCIM_URN_USER_ENTERPRISE: {
|
||||||
|
"employeeNumber": "XHDMEJUURJNR",
|
||||||
|
"costCenter": "RXUYBXOTRCZH",
|
||||||
|
"organization": "CEXWXMBRYAHN",
|
||||||
|
"division": "XMPFMDCLRKCW",
|
||||||
|
"department": "BKMNJVMCJUYS",
|
||||||
|
"manager": "PNGSGXLYVWMV",
|
||||||
|
},
|
||||||
|
"userName": "imelda.auer@kshlerin.co.uk",
|
||||||
|
"userType": "PZFXORVSUAPU",
|
||||||
|
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"active": True,
|
||||||
|
"addresses": [
|
||||||
|
{
|
||||||
|
"primary": "true",
|
||||||
|
"formatted": "BLJMCNXHYLZK",
|
||||||
|
"streetAddress": "7801 Jacobs Fork",
|
||||||
|
"locality": "HZJBJWFAKXDD",
|
||||||
|
"region": "GJXCXPMIIKWK",
|
||||||
|
"postalCode": "pv82 8ua",
|
||||||
|
"country": "India",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"displayName": "DPRLIJSFQMTL",
|
||||||
|
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
|
||||||
|
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
|
||||||
|
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
|
||||||
|
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
|
||||||
|
"locale": "JLOJHLPWZODG",
|
||||||
|
"name": {
|
||||||
|
"formatted": "Dell",
|
||||||
|
"familyName": "Gay",
|
||||||
|
"givenName": "Kyler",
|
||||||
|
"middleName": "Hannah",
|
||||||
|
"honorificPrefix": "Cassie",
|
||||||
|
"honorificSuffix": "Yolanda",
|
||||||
|
},
|
||||||
|
"nickName": "BKSPMIRMFBTI",
|
||||||
|
"phoneNumbers": [
|
||||||
|
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
|
||||||
|
],
|
||||||
|
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
|
||||||
|
"preferredLanguage": "as-IN",
|
||||||
|
"profileUrl": "HPSEOIPXMGOH",
|
||||||
|
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"timezone": "America/Argentina/Rio_Gallegos",
|
||||||
|
"title": "NBZCOAXVYJUY",
|
||||||
|
SCIM_URN_USER_ENTERPRISE: {
|
||||||
|
"employeeNumber": "PDFWRRZBQOHB",
|
||||||
|
"costCenter": "HACMZWSEDOTQ",
|
||||||
|
"organization": "LXVHJUOLNCLS",
|
||||||
|
"division": "JASVTPKPBPMG",
|
||||||
|
"department": "GMSBFLMNPABY",
|
||||||
|
"manager": "PNGSGXLYVWMV",
|
||||||
|
},
|
||||||
|
"userName": "imelda.auer@kshlerin.co.uk",
|
||||||
|
"userType": "ZGJMYZRUORZE",
|
||||||
|
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_schema_urn_manager(self):
|
||||||
|
req = {
|
||||||
|
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||||
|
"Operations": [
|
||||||
|
{
|
||||||
|
"op": "Add",
|
||||||
|
"value": {
|
||||||
|
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager": "foo"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
user = create_test_user()
|
||||||
|
source = SCIMSource.objects.create(slug=generate_id())
|
||||||
|
connection = SCIMSourceUser.objects.create(
|
||||||
|
user=user,
|
||||||
|
id=generate_id(),
|
||||||
|
source=source,
|
||||||
|
attributes={
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||||
|
self.assertEqual(
|
||||||
|
updated,
|
||||||
|
{
|
||||||
|
"meta": {"resourceType": "User"},
|
||||||
|
"active": True,
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
SCIM_URN_USER_ENTERPRISE,
|
||||||
|
],
|
||||||
|
"userName": "test@t.goauthentik.io",
|
||||||
|
"externalId": "test",
|
||||||
|
"displayName": "Test MS",
|
||||||
|
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": {
|
||||||
|
"manager": {"value": "foo"}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""SCIM Utils"""
|
"""SCIM Utils"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.paginator import Page, Paginator
|
from django.core.paginator import Page, Paginator
|
||||||
@@ -21,6 +22,7 @@ from authentik.core.sources.mapper import SourceMapper
|
|||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||||
from authentik.sources.scim.models import SCIMSource
|
from authentik.sources.scim.models import SCIMSource
|
||||||
from authentik.sources.scim.views.v2.auth import SCIMTokenAuth
|
from authentik.sources.scim.views.v2.auth import SCIMTokenAuth
|
||||||
|
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
|
||||||
|
|
||||||
SCIM_CONTENT_TYPE = "application/scim+json"
|
SCIM_CONTENT_TYPE = "application/scim+json"
|
||||||
|
|
||||||
@@ -54,6 +56,13 @@ class SCIMView(APIView):
|
|||||||
def get_authenticators(self):
|
def get_authenticators(self):
|
||||||
return [SCIMTokenAuth(self)]
|
return [SCIMTokenAuth(self)]
|
||||||
|
|
||||||
|
def remove_excluded_attributes(self, data: dict):
|
||||||
|
"""Remove attributes specified in excludedAttributes"""
|
||||||
|
excluded: str = self.request.query_params.get("excludedAttributes", "")
|
||||||
|
for key in excluded.split(","):
|
||||||
|
data.pop(key.strip(), None)
|
||||||
|
return data
|
||||||
|
|
||||||
def filter_parse(self, request: Request):
|
def filter_parse(self, request: Request):
|
||||||
"""Parse the path of a Patch Operation"""
|
"""Parse the path of a Patch Operation"""
|
||||||
path = request.query_params.get("filter")
|
path = request.query_params.get("filter")
|
||||||
@@ -103,6 +112,12 @@ class SCIMObjectView(SCIMView):
|
|||||||
# a source attribute before
|
# a source attribute before
|
||||||
self.mapper = SourceMapper(self.source)
|
self.mapper = SourceMapper(self.source)
|
||||||
self.manager = self.mapper.get_manager(self.model, ["data"])
|
self.manager = self.mapper.get_manager(self.model, ["data"])
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key.endswith("_id"):
|
||||||
|
try:
|
||||||
|
UUID(value)
|
||||||
|
except ValueError:
|
||||||
|
raise SCIMNotFoundError("Invalid ID") from None
|
||||||
|
|
||||||
def build_object_properties(self, data: dict[str, Any]) -> dict[str, Any | dict[str, Any]]:
|
def build_object_properties(self, data: dict[str, Any]) -> dict[str, Any | dict[str, Any]]:
|
||||||
return self.mapper.build_object_properties(
|
return self.mapper.build_object_properties(
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from authentik.core.models import Group, User
|
|||||||
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
|
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
|
||||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
|
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
|
||||||
from authentik.sources.scim.models import SCIMSourceGroup
|
from authentik.sources.scim.models import SCIMSourceGroup
|
||||||
|
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
||||||
from authentik.sources.scim.views.v2.exceptions import (
|
from authentik.sources.scim.views.v2.exceptions import (
|
||||||
SCIMConflictError,
|
SCIMConflictError,
|
||||||
@@ -35,11 +36,12 @@ class GroupsView(SCIMObjectView):
|
|||||||
payload = SCIMGroupModel(
|
payload = SCIMGroupModel(
|
||||||
schemas=[SCIM_GROUP_SCHEMA],
|
schemas=[SCIM_GROUP_SCHEMA],
|
||||||
id=str(scim_group.group.pk),
|
id=str(scim_group.group.pk),
|
||||||
externalId=scim_group.id,
|
externalId=scim_group.external_id,
|
||||||
displayName=scim_group.group.name,
|
displayName=scim_group.group.name,
|
||||||
members=[],
|
members=[],
|
||||||
meta={
|
meta={
|
||||||
"resourceType": "Group",
|
"resourceType": "Group",
|
||||||
|
"lastModified": scim_group.last_update,
|
||||||
"location": self.request.build_absolute_uri(
|
"location": self.request.build_absolute_uri(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-groups",
|
"authentik_sources_scim:v2-groups",
|
||||||
@@ -54,7 +56,11 @@ class GroupsView(SCIMObjectView):
|
|||||||
for member in scim_group.group.users.order_by("pk"):
|
for member in scim_group.group.users.order_by("pk"):
|
||||||
member: User
|
member: User
|
||||||
payload.members.append(GroupMember(value=str(member.uuid)))
|
payload.members.append(GroupMember(value=str(member.uuid)))
|
||||||
return payload.model_dump(mode="json", exclude_unset=True)
|
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||||
|
final_payload.update(scim_group.attributes)
|
||||||
|
return self.remove_excluded_attributes(
|
||||||
|
SCIMGroupModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
|
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
|
||||||
"""List Group handler"""
|
"""List Group handler"""
|
||||||
@@ -81,7 +87,7 @@ class GroupsView(SCIMObjectView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@atomic
|
@atomic
|
||||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
|
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict, apply_members=True):
|
||||||
"""Partial update a group"""
|
"""Partial update a group"""
|
||||||
properties = self.build_object_properties(data)
|
properties = self.build_object_properties(data)
|
||||||
|
|
||||||
@@ -94,7 +100,7 @@ class GroupsView(SCIMObjectView):
|
|||||||
|
|
||||||
group.update_attributes(properties)
|
group.update_attributes(properties)
|
||||||
|
|
||||||
if "members" in data:
|
if "members" in data and apply_members:
|
||||||
query = Q()
|
query = Q()
|
||||||
for _member in data.get("members", []):
|
for _member in data.get("members", []):
|
||||||
try:
|
try:
|
||||||
@@ -105,14 +111,18 @@ class GroupsView(SCIMObjectView):
|
|||||||
query |= Q(uuid=member.value)
|
query |= Q(uuid=member.value)
|
||||||
if query:
|
if query:
|
||||||
group.users.set(User.objects.filter(query))
|
group.users.set(User.objects.filter(query))
|
||||||
|
data["members"] = self._convert_members(group)
|
||||||
if not connection:
|
if not connection:
|
||||||
connection, _ = SCIMSourceGroup.objects.get_or_create(
|
connection, _ = SCIMSourceGroup.objects.update_or_create(
|
||||||
|
external_id=data.get("externalId") or str(uuid4()),
|
||||||
source=self.source,
|
source=self.source,
|
||||||
group=group,
|
group=group,
|
||||||
attributes=data,
|
defaults={
|
||||||
id=data.get("externalId") or str(uuid4()),
|
"attributes": data,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
connection.external_id = data.get("externalId", connection.external_id)
|
||||||
connection.attributes = data
|
connection.attributes = data
|
||||||
connection.save()
|
connection.save()
|
||||||
return connection
|
return connection
|
||||||
@@ -139,6 +149,12 @@ class GroupsView(SCIMObjectView):
|
|||||||
connection = self.update_group(connection, request.data)
|
connection = self.update_group(connection, request.data)
|
||||||
return Response(self.group_to_scim(connection), status=200)
|
return Response(self.group_to_scim(connection), status=200)
|
||||||
|
|
||||||
|
def _convert_members(self, group: Group):
|
||||||
|
users = []
|
||||||
|
for user in group.users.all().order_by("uuid"):
|
||||||
|
users.append({"value": str(user.uuid)})
|
||||||
|
return sorted(users, key=lambda u: u["value"])
|
||||||
|
|
||||||
@atomic
|
@atomic
|
||||||
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
|
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
|
||||||
"""Patch group handler"""
|
"""Patch group handler"""
|
||||||
@@ -171,6 +187,13 @@ class GroupsView(SCIMObjectView):
|
|||||||
query |= Q(uuid=member["value"])
|
query |= Q(uuid=member["value"])
|
||||||
if query:
|
if query:
|
||||||
connection.group.users.remove(*User.objects.filter(query))
|
connection.group.users.remove(*User.objects.filter(query))
|
||||||
|
patcher = SCIMPatchProcessor()
|
||||||
|
patched_data = patcher.apply_patches(
|
||||||
|
connection.attributes, request.data.get("Operations", [])
|
||||||
|
)
|
||||||
|
patched_data["members"] = self._convert_members(connection.group)
|
||||||
|
if patched_data != connection.attributes:
|
||||||
|
self.update_group(connection, patched_data, apply_members=False)
|
||||||
return Response(self.group_to_scim(connection), status=200)
|
return Response(self.group_to_scim(connection), status=200)
|
||||||
|
|
||||||
@atomic
|
@atomic
|
||||||
|
|||||||
@@ -33,9 +33,7 @@ class ServiceProviderConfigView(SCIMView):
|
|||||||
{
|
{
|
||||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
|
||||||
"authenticationSchemes": auth_schemas,
|
"authenticationSchemes": auth_schemas,
|
||||||
# We only support patch for groups currently, so don't broadly advertise it.
|
"patch": {"supported": True},
|
||||||
# Implementations that require Group patch will use it regardless of this flag.
|
|
||||||
"patch": {"supported": False},
|
|
||||||
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
|
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
|
||||||
"filter": {
|
"filter": {
|
||||||
"supported": True,
|
"supported": True,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from authentik.core.models import User
|
|||||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
|
from authentik.providers.scim.clients.schema import User as SCIMUserModel
|
||||||
from authentik.sources.scim.models import SCIMSourceUser
|
from authentik.sources.scim.models import SCIMSourceUser
|
||||||
|
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
||||||
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
|
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ class UsersView(SCIMObjectView):
|
|||||||
payload = SCIMUserModel(
|
payload = SCIMUserModel(
|
||||||
schemas=[SCIM_USER_SCHEMA],
|
schemas=[SCIM_USER_SCHEMA],
|
||||||
id=str(scim_user.user.uuid),
|
id=str(scim_user.user.uuid),
|
||||||
externalId=scim_user.id,
|
externalId=scim_user.external_id,
|
||||||
userName=scim_user.user.username,
|
userName=scim_user.user.username,
|
||||||
name=Name(
|
name=Name(
|
||||||
formatted=scim_user.user.name,
|
formatted=scim_user.user.name,
|
||||||
@@ -44,8 +45,7 @@ class UsersView(SCIMObjectView):
|
|||||||
meta={
|
meta={
|
||||||
"resourceType": "User",
|
"resourceType": "User",
|
||||||
"created": scim_user.user.date_joined,
|
"created": scim_user.user.date_joined,
|
||||||
# TODO: use events to find last edit?
|
"lastModified": scim_user.last_update,
|
||||||
"lastModified": scim_user.user.date_joined,
|
|
||||||
"location": self.request.build_absolute_uri(
|
"location": self.request.build_absolute_uri(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_sources_scim:v2-users",
|
"authentik_sources_scim:v2-users",
|
||||||
@@ -59,7 +59,9 @@ class UsersView(SCIMObjectView):
|
|||||||
)
|
)
|
||||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||||
final_payload.update(scim_user.attributes)
|
final_payload.update(scim_user.attributes)
|
||||||
return final_payload
|
return self.remove_excluded_attributes(
|
||||||
|
SCIMUserModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, request: Request, user_id: str | None = None, **kwargs) -> Response:
|
def get(self, request: Request, user_id: str | None = None, **kwargs) -> Response:
|
||||||
"""List User handler"""
|
"""List User handler"""
|
||||||
@@ -101,13 +103,16 @@ class UsersView(SCIMObjectView):
|
|||||||
user.update_attributes(properties)
|
user.update_attributes(properties)
|
||||||
|
|
||||||
if not connection:
|
if not connection:
|
||||||
connection, _ = SCIMSourceUser.objects.get_or_create(
|
connection, _ = SCIMSourceUser.objects.update_or_create(
|
||||||
|
external_id=data.get("externalId") or str(uuid4()),
|
||||||
source=self.source,
|
source=self.source,
|
||||||
user=user,
|
user=user,
|
||||||
attributes=data,
|
defaults={
|
||||||
id=data.get("externalId") or str(uuid4()),
|
"attributes": data,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
connection.external_id = data.get("externalId", connection.external_id)
|
||||||
connection.attributes = data
|
connection.attributes = data
|
||||||
connection.save()
|
connection.save()
|
||||||
return connection
|
return connection
|
||||||
@@ -127,6 +132,18 @@ class UsersView(SCIMObjectView):
|
|||||||
connection = self.update_user(None, request.data)
|
connection = self.update_user(None, request.data)
|
||||||
return Response(self.user_to_scim(connection), status=201)
|
return Response(self.user_to_scim(connection), status=201)
|
||||||
|
|
||||||
|
def patch(self, request: Request, user_id: str, **kwargs):
|
||||||
|
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
||||||
|
if not connection:
|
||||||
|
raise SCIMNotFoundError("User not found.")
|
||||||
|
patcher = SCIMPatchProcessor()
|
||||||
|
patched_data = patcher.apply_patches(
|
||||||
|
connection.attributes, request.data.get("Operations", [])
|
||||||
|
)
|
||||||
|
if patched_data != connection.attributes:
|
||||||
|
self.update_user(connection, patched_data)
|
||||||
|
return Response(self.user_to_scim(connection), status=200)
|
||||||
|
|
||||||
def put(self, request: Request, user_id: str, **kwargs) -> Response:
|
def put(self, request: Request, user_id: str, **kwargs) -> Response:
|
||||||
"""Update user handler"""
|
"""Update user handler"""
|
||||||
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from authentik.flows.exceptions import StageInvalidException
|
|||||||
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.models import SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.lib.utils.time import timedelta_string_validator
|
from authentik.lib.utils.time import timedelta_string_validator
|
||||||
from authentik.stages.authenticator.models import SideChannelDevice
|
from authentik.stages.authenticator.models import SideChannelDevice
|
||||||
from authentik.stages.email.utils import TemplateEmailMessage
|
from authentik.stages.email.utils import TemplateEmailMessage
|
||||||
@@ -160,9 +159,8 @@ class EmailDevice(SerializerModel, SideChannelDevice):
|
|||||||
Event.new(
|
Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message=_("Exception occurred while rendering E-mail template"),
|
message=_("Exception occurred while rendering E-mail template"),
|
||||||
error=exception_to_string(exc),
|
|
||||||
template=stage.template,
|
template=stage.template,
|
||||||
).from_http(self.request)
|
).with_exception(exc).from_http(self.request)
|
||||||
raise StageInvalidException from exc
|
raise StageInvalidException from exc
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from authentik.flows.challenge import (
|
|||||||
from authentik.flows.exceptions import StageInvalidException
|
from authentik.flows.exceptions import StageInvalidException
|
||||||
from authentik.flows.stage import ChallengeStageView
|
from authentik.flows.stage import ChallengeStageView
|
||||||
from authentik.lib.utils.email import mask_email
|
from authentik.lib.utils.email import mask_email
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.stages.authenticator_email.models import (
|
from authentik.stages.authenticator_email.models import (
|
||||||
AuthenticatorEmailStage,
|
AuthenticatorEmailStage,
|
||||||
@@ -100,9 +99,8 @@ class AuthenticatorEmailStageView(ChallengeStageView):
|
|||||||
Event.new(
|
Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message=_("Exception occurred while rendering E-mail template"),
|
message=_("Exception occurred while rendering E-mail template"),
|
||||||
error=exception_to_string(exc),
|
|
||||||
template=stage.template,
|
template=stage.template,
|
||||||
).from_http(self.request)
|
).with_exception(exc).from_http(self.request)
|
||||||
raise StageInvalidException from exc
|
raise StageInvalidException from exc
|
||||||
|
|
||||||
def _has_email(self) -> str | None:
|
def _has_email(self) -> str | None:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from authentik.events.models import Event, EventAction, NotificationWebhookMappi
|
|||||||
from authentik.events.utils import sanitize_item
|
from authentik.events.utils import sanitize_item
|
||||||
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
||||||
from authentik.lib.models import SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.stages.authenticator.models import SideChannelDevice
|
from authentik.stages.authenticator.models import SideChannelDevice
|
||||||
|
|
||||||
@@ -142,10 +141,9 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
|||||||
Event.new(
|
Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message="Error sending SMS",
|
message="Error sending SMS",
|
||||||
exc=exception_to_string(exc),
|
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
body=response.text,
|
body=response.text,
|
||||||
).set_user(device.user).save()
|
).with_exception(exc).set_user(device.user).save()
|
||||||
if response.status_code >= HttpResponseBadRequest.status_code:
|
if response.status_code >= HttpResponseBadRequest.status_code:
|
||||||
raise ValidationError(response.text) from None
|
raise ValidationError(response.text) from None
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from django.http.response import Http404
|
|||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.utils.translation import gettext as __
|
from django.utils.translation import gettext as __
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.fields import CharField, DateTimeField
|
from rest_framework.fields import CharField, ChoiceField, DateTimeField
|
||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
from webauthn import options_to_json
|
from webauthn import options_to_json
|
||||||
@@ -18,7 +18,7 @@ from webauthn.authentication.verify_authentication_response import verify_authen
|
|||||||
from webauthn.helpers import parse_authentication_credential_json
|
from webauthn.helpers import parse_authentication_credential_json
|
||||||
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
|
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
|
||||||
from webauthn.helpers.exceptions import InvalidAuthenticationResponse, InvalidJSONStructure
|
from webauthn.helpers.exceptions import InvalidAuthenticationResponse, InvalidJSONStructure
|
||||||
from webauthn.helpers.structs import UserVerificationRequirement
|
from webauthn.helpers.structs import PublicKeyCredentialType, UserVerificationRequirement
|
||||||
|
|
||||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
from authentik.core.models import Application, User
|
from authentik.core.models import Application, User
|
||||||
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
|
|||||||
class DeviceChallenge(PassiveSerializer):
|
class DeviceChallenge(PassiveSerializer):
|
||||||
"""Single device challenge"""
|
"""Single device challenge"""
|
||||||
|
|
||||||
device_class = CharField()
|
device_class = ChoiceField(choices=DeviceClasses.choices)
|
||||||
device_uid = CharField()
|
device_uid = CharField()
|
||||||
challenge = JSONDictField()
|
challenge = JSONDictField()
|
||||||
last_used = DateTimeField(allow_null=True)
|
last_used = DateTimeField(allow_null=True)
|
||||||
@@ -157,6 +157,12 @@ def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -
|
|||||||
request = stage_view.request
|
request = stage_view.request
|
||||||
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
|
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
|
||||||
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
|
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
|
||||||
|
if "MinuteMaid" in request.META.get("HTTP_USER_AGENT", ""):
|
||||||
|
# Workaround for Android sign-in, when signing into Google Workspace on android while
|
||||||
|
# adding the account to the system (not in Chrome), for some reason `type` is not set
|
||||||
|
# so in that case we fall back to `public-key`
|
||||||
|
# since that's the only option we support anyways
|
||||||
|
data.setdefault("type", PublicKeyCredentialType.PUBLIC_KEY)
|
||||||
try:
|
try:
|
||||||
credential = parse_authentication_credential_json(data)
|
credential = parse_authentication_credential_json(data)
|
||||||
except InvalidJSONStructure as exc:
|
except InvalidJSONStructure as exc:
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
|
|||||||
{
|
{
|
||||||
"auth_method": "auth_mfa",
|
"auth_method": "auth_mfa",
|
||||||
"auth_method_args": {
|
"auth_method_args": {
|
||||||
|
"known_device": False,
|
||||||
"mfa_devices": [
|
"mfa_devices": [
|
||||||
{
|
{
|
||||||
"app": "authentik_stages_authenticator_duo",
|
"app": "authentik_stages_authenticator_duo",
|
||||||
@@ -180,7 +181,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
|
|||||||
"name": "",
|
"name": "",
|
||||||
"pk": duo_device.pk,
|
"pk": duo_device.pk,
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
},
|
},
|
||||||
"http_request": {
|
"http_request": {
|
||||||
"args": {},
|
"args": {},
|
||||||
|
|||||||
@@ -153,13 +153,13 @@ class AuthenticatorValidateStageTests(FlowTestCase):
|
|||||||
plan.append_stage(stage)
|
plan.append_stage(stage)
|
||||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||||
{
|
{
|
||||||
"device_class": "static",
|
"device_class": DeviceClasses.STATIC,
|
||||||
"device_uid": "1",
|
"device_uid": "1",
|
||||||
"challenge": {},
|
"challenge": {},
|
||||||
"last_used": now(),
|
"last_used": now(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"device_class": "totp",
|
"device_class": DeviceClasses.TOTP,
|
||||||
"device_uid": "2",
|
"device_uid": "2",
|
||||||
"challenge": {},
|
"challenge": {},
|
||||||
"last_used": now(),
|
"last_used": now(),
|
||||||
@@ -172,7 +172,7 @@ class AuthenticatorValidateStageTests(FlowTestCase):
|
|||||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
data={
|
data={
|
||||||
"selected_challenge": {
|
"selected_challenge": {
|
||||||
"device_class": "baz",
|
"device_class": DeviceClasses.WEBAUTHN,
|
||||||
"device_uid": "quox",
|
"device_uid": "quox",
|
||||||
"challenge": {},
|
"challenge": {},
|
||||||
"last_used": None,
|
"last_used": None,
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
|||||||
session = self.client.session
|
session = self.client.session
|
||||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||||
plan.append_stage(stage)
|
plan.append_stage(stage)
|
||||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||||
session[SESSION_KEY_PLAN] = plan
|
session[SESSION_KEY_PLAN] = plan
|
||||||
session.save()
|
session.save()
|
||||||
@@ -282,7 +282,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
|||||||
session = self.client.session
|
session = self.client.session
|
||||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||||
plan.append_stage(stage)
|
plan.append_stage(stage)
|
||||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||||
{
|
{
|
||||||
@@ -359,7 +359,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
|||||||
session = self.client.session
|
session = self.client.session
|
||||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||||
plan.append_stage(stage)
|
plan.append_stage(stage)
|
||||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||||
{
|
{
|
||||||
@@ -441,7 +441,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
|||||||
session = self.client.session
|
session = self.client.session
|
||||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||||
plan.append_stage(stage)
|
plan.append_stage(stage)
|
||||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||||
{
|
{
|
||||||
"device_class": device.__class__.__name__.lower().replace("device", ""),
|
"device_class": device.__class__.__name__.lower().replace("device", ""),
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -21,7 +21,6 @@ from authentik.flows.models import FlowDesignation, FlowToken
|
|||||||
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
|
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
|
||||||
from authentik.flows.stage import ChallengeStageView
|
from authentik.flows.stage import ChallengeStageView
|
||||||
from authentik.flows.views.executor import QS_KEY_TOKEN, QS_QUERY
|
from authentik.flows.views.executor import QS_KEY_TOKEN, QS_QUERY
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.stages.email.flow import pickle_flow_token_for_email
|
from authentik.stages.email.flow import pickle_flow_token_for_email
|
||||||
from authentik.stages.email.models import EmailStage
|
from authentik.stages.email.models import EmailStage
|
||||||
@@ -129,9 +128,8 @@ class EmailStageView(ChallengeStageView):
|
|||||||
Event.new(
|
Event.new(
|
||||||
EventAction.CONFIGURATION_ERROR,
|
EventAction.CONFIGURATION_ERROR,
|
||||||
message=_("Exception occurred while rendering E-mail template"),
|
message=_("Exception occurred while rendering E-mail template"),
|
||||||
error=exception_to_string(exc),
|
|
||||||
template=current_stage.template,
|
template=current_stage.template,
|
||||||
).from_http(self.request)
|
).with_exception(exc).from_http(self.request)
|
||||||
raise StageInvalidException from exc
|
raise StageInvalidException from exc
|
||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
@@ -145,7 +143,7 @@ class EmailStageView(ChallengeStageView):
|
|||||||
messages.success(request, _("Successfully verified Email."))
|
messages.success(request, _("Successfully verified Email."))
|
||||||
if self.executor.current_stage.activate_user_on_success:
|
if self.executor.current_stage.activate_user_on_success:
|
||||||
user.is_active = True
|
user.is_active = True
|
||||||
user.save()
|
user.save(update_fields=["is_active"])
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
|
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
|
||||||
self.logger.debug("No pending user")
|
self.logger.debug("No pending user")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Prompt Stage Logic"""
|
"""Prompt Stage Logic"""
|
||||||
|
|
||||||
from collections.abc import Callable, Iterator
|
from collections.abc import Callable
|
||||||
from email.policy import Policy
|
from email.policy import Policy
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -190,10 +190,11 @@ class ListPolicyEngine(PolicyEngine):
|
|||||||
self.__list = policies
|
self.__list = policies
|
||||||
self.use_cache = False
|
self.use_cache = False
|
||||||
|
|
||||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
def bindings(self):
|
||||||
for policy in self.__list:
|
for idx, policy in enumerate(self.__list):
|
||||||
yield PolicyBinding(
|
yield PolicyBinding(
|
||||||
policy=policy,
|
policy=policy,
|
||||||
|
order=idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class TestPromptStage(FlowTestCase):
|
|||||||
"""Test challenge_response validation"""
|
"""Test challenge_response validation"""
|
||||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||||
expr = "False"
|
expr = "False"
|
||||||
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
expr_policy = ExpressionPolicy.objects.create(name=generate_id(), expression=expr)
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
@@ -222,6 +222,18 @@ class TestPromptStage(FlowTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), False)
|
self.assertEqual(challenge_response.is_valid(), False)
|
||||||
|
|
||||||
|
def test_invalid_challenge_multiple(self):
|
||||||
|
"""Test challenge_response validation (multiple policies)"""
|
||||||
|
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||||
|
expr_policy1 = ExpressionPolicy.objects.create(name=generate_id(), expression="False")
|
||||||
|
expr_policy2 = ExpressionPolicy.objects.create(name=generate_id(), expression="False")
|
||||||
|
self.stage.validation_policies.set([expr_policy1, expr_policy2])
|
||||||
|
self.stage.save()
|
||||||
|
challenge_response = PromptChallengeResponse(
|
||||||
|
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||||
|
)
|
||||||
|
self.assertEqual(challenge_response.is_valid(), False)
|
||||||
|
|
||||||
def test_valid_challenge_request(self):
|
def test_valid_challenge_request(self):
|
||||||
"""Test a request with valid challenge_response data"""
|
"""Test a request with valid challenge_response data"""
|
||||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||||
@@ -234,7 +246,7 @@ class TestPromptStage(FlowTestCase):
|
|||||||
"return request.context['prompt_data']['password_prompt'] "
|
"return request.context['prompt_data']['password_prompt'] "
|
||||||
"== request.context['prompt_data']['password2_prompt']"
|
"== request.context['prompt_data']['password2_prompt']"
|
||||||
)
|
)
|
||||||
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
expr_policy = ExpressionPolicy.objects.create(name=generate_id(), expression=expr)
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class UserLoginStageSerializer(StageSerializer):
|
|||||||
"remember_me_offset",
|
"remember_me_offset",
|
||||||
"network_binding",
|
"network_binding",
|
||||||
"geoip_binding",
|
"geoip_binding",
|
||||||
|
"remember_device",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from django.contrib.auth.views import redirect_to_login
|
|||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.core.models import Session
|
||||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
@@ -89,7 +90,7 @@ class BoundSessionMiddleware(SessionMiddleware):
|
|||||||
|
|
||||||
def recheck_session(self, request: HttpRequest):
|
def recheck_session(self, request: HttpRequest):
|
||||||
"""Check if a session is still valid with a changed IP"""
|
"""Check if a session is still valid with a changed IP"""
|
||||||
last_ip = request.session.get(request.session.model.Keys.LAST_IP)
|
last_ip = request.session.get(Session.Keys.LAST_IP)
|
||||||
new_ip = ClientIPMiddleware.get_client_ip(request)
|
new_ip = ClientIPMiddleware.get_client_ip(request)
|
||||||
# Check changed IP
|
# Check changed IP
|
||||||
if new_ip == last_ip:
|
if new_ip == last_ip:
|
||||||
@@ -109,7 +110,7 @@ class BoundSessionMiddleware(SessionMiddleware):
|
|||||||
if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
|
if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
|
||||||
# Only set the last IP in the session if there's a binding specified
|
# Only set the last IP in the session if there's a binding specified
|
||||||
# (== basically requires the user to be logged in)
|
# (== basically requires the user to be logged in)
|
||||||
request.session[request.session.model.Keys.LAST_IP] = new_ip
|
request.session[Session.Keys.LAST_IP] = new_ip
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):
|
def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
# Generated by Django 5.1.11 on 2025-06-18 16:21
|
||||||
|
|
||||||
|
import authentik.lib.utils.time
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_stages_user_login", "0006_userloginstage_geoip_binding_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="userloginstage",
|
||||||
|
name="remember_device",
|
||||||
|
field=models.TextField(
|
||||||
|
default="days=30",
|
||||||
|
help_text="When set to a non-zero value, authentik will save a cookie with a longer expiry,to remember the device the user is logging in from. (Format: hours=-1;minutes=-2;seconds=-3)",
|
||||||
|
validators=[authentik.lib.utils.time.timedelta_string_validator],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -63,6 +63,15 @@ class UserLoginStage(Stage):
|
|||||||
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
remember_device = models.TextField(
|
||||||
|
default="days=30",
|
||||||
|
validators=[timedelta_string_validator],
|
||||||
|
help_text=_(
|
||||||
|
"When set to a non-zero value, authentik will save a cookie with a longer expiry,"
|
||||||
|
"to remember the device the user is logging in from. "
|
||||||
|
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> type[BaseSerializer]:
|
def serializer(self) -> type[BaseSerializer]:
|
||||||
|
|||||||
@@ -1,14 +1,17 @@
|
|||||||
"""Login stage logic"""
|
"""Login stage logic"""
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from hashlib import sha256
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.contrib.auth import login
|
from django.contrib.auth import login
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
from jwt import PyJWTError, decode, encode
|
||||||
from rest_framework.fields import BooleanField, CharField
|
from rest_framework.fields import BooleanField, CharField
|
||||||
|
|
||||||
from authentik.core.models import Session, User
|
from authentik.core.models import AuthenticatedSession, Session, User
|
||||||
from authentik.events.middleware import audit_ignore
|
from authentik.events.middleware import audit_ignore
|
||||||
from authentik.flows.challenge import ChallengeResponse, WithUserInfoChallenge
|
from authentik.flows.challenge import ChallengeResponse, WithUserInfoChallenge
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||||
@@ -16,12 +19,20 @@ from authentik.flows.stage import ChallengeStageView
|
|||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.root.middleware import ClientIPMiddleware
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.password import BACKEND_INBUILT
|
from authentik.stages.password import BACKEND_INBUILT
|
||||||
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
|
from authentik.stages.password.stage import (
|
||||||
|
PLAN_CONTEXT_AUTHENTICATION_BACKEND,
|
||||||
|
PLAN_CONTEXT_METHOD_ARGS,
|
||||||
|
)
|
||||||
from authentik.stages.user_login.middleware import (
|
from authentik.stages.user_login.middleware import (
|
||||||
SESSION_KEY_BINDING_GEO,
|
SESSION_KEY_BINDING_GEO,
|
||||||
SESSION_KEY_BINDING_NET,
|
SESSION_KEY_BINDING_NET,
|
||||||
)
|
)
|
||||||
from authentik.stages.user_login.models import UserLoginStage
|
from authentik.stages.user_login.models import UserLoginStage
|
||||||
|
from authentik.tenants.utils import get_unique_identifier
|
||||||
|
|
||||||
|
COOKIE_NAME_KNOWN_DEVICE = "authentik_device"
|
||||||
|
|
||||||
|
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE = "known_device"
|
||||||
|
|
||||||
|
|
||||||
class UserLoginChallenge(WithUserInfoChallenge):
|
class UserLoginChallenge(WithUserInfoChallenge):
|
||||||
@@ -78,12 +89,63 @@ class UserLoginStageView(ChallengeStageView):
|
|||||||
self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding
|
self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding
|
||||||
self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding
|
self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding
|
||||||
|
|
||||||
def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse:
|
# FIXME: identical function in authenticator_validate
|
||||||
"""Attach the currently pending user to the current session"""
|
@property
|
||||||
|
def cookie_jwt_key(self) -> str:
|
||||||
|
"""Signing key for Known-device Cookie for this stage"""
|
||||||
|
return sha256(
|
||||||
|
f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii")
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
def set_known_device_cookie(self, user: User):
|
||||||
|
"""Set a cookie, valid longer than the session, which denotes that this user
|
||||||
|
has logged in on this device before."""
|
||||||
|
delta = timedelta_from_string(self.executor.current_stage.remember_device)
|
||||||
|
response = self.executor.stage_ok()
|
||||||
|
if delta.total_seconds() < 1:
|
||||||
|
return response
|
||||||
|
expiry = datetime.now() + delta
|
||||||
|
cookie_payload = {
|
||||||
|
"sub": user.uid,
|
||||||
|
"exp": expiry.timestamp(),
|
||||||
|
}
|
||||||
|
cookie = encode(cookie_payload, self.cookie_jwt_key)
|
||||||
|
response.set_cookie(
|
||||||
|
COOKIE_NAME_KNOWN_DEVICE,
|
||||||
|
cookie,
|
||||||
|
expires=expiry,
|
||||||
|
path=settings.SESSION_COOKIE_PATH,
|
||||||
|
domain=settings.SESSION_COOKIE_DOMAIN,
|
||||||
|
samesite=settings.SESSION_COOKIE_SAMESITE,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def is_known_device(self, user: User):
|
||||||
|
"""Returns `True` if the login happened on a "known" device, by the same user."""
|
||||||
|
client_ip = ClientIPMiddleware.get_client_ip(self.request)
|
||||||
|
if AuthenticatedSession.objects.filter(session__last_ip=client_ip).exists():
|
||||||
|
return True
|
||||||
|
if COOKIE_NAME_KNOWN_DEVICE not in self.request.COOKIES:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
payload = decode(
|
||||||
|
self.request.COOKIES[COOKIE_NAME_KNOWN_DEVICE], self.cookie_jwt_key, ["HS256"]
|
||||||
|
)
|
||||||
|
if payload["sub"] == user.uid:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except (PyJWTError, ValueError, TypeError) as exc:
|
||||||
|
self.logger.info("eh", exc=exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def do_login(self, request: HttpRequest, remember: bool | None = None) -> HttpResponse:
|
||||||
|
"""Attach the currently pending user to the current session.
|
||||||
|
`remember` Argument should be `None` if not configured, otherwise set to `True`/`False`
|
||||||
|
representative of the user's choice."""
|
||||||
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
|
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
|
||||||
message = _("No Pending user to login.")
|
message = _("No Pending user to login.")
|
||||||
messages.error(request, message)
|
messages.error(request, message)
|
||||||
self.logger.debug(message)
|
self.logger.warning(message)
|
||||||
return self.executor.stage_invalid()
|
return self.executor.stage_invalid()
|
||||||
backend = self.executor.plan.context.get(
|
backend = self.executor.plan.context.get(
|
||||||
PLAN_CONTEXT_AUTHENTICATION_BACKEND, BACKEND_INBUILT
|
PLAN_CONTEXT_AUTHENTICATION_BACKEND, BACKEND_INBUILT
|
||||||
@@ -91,8 +153,13 @@ class UserLoginStageView(ChallengeStageView):
|
|||||||
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
self.logger.warning("User is not active, login will not work.")
|
self.logger.warning("User is not active, login will not work.")
|
||||||
delta = self.set_session_duration(remember)
|
delta = self.set_session_duration(bool(remember))
|
||||||
self.set_session_ip()
|
self.set_session_ip()
|
||||||
|
# Check if the login request is coming from a known device
|
||||||
|
self.executor.plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||||
|
self.executor.plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||||
|
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, self.is_known_device(user)
|
||||||
|
)
|
||||||
# the `user_logged_in` signal will update the user to write the `last_login` field
|
# the `user_logged_in` signal will update the user to write the `last_login` field
|
||||||
# which we don't want to log as we already have a dedicated login event
|
# which we don't want to log as we already have a dedicated login event
|
||||||
with audit_ignore():
|
with audit_ignore():
|
||||||
@@ -112,4 +179,6 @@ class UserLoginStageView(ChallengeStageView):
|
|||||||
Session.objects.filter(
|
Session.objects.filter(
|
||||||
authenticatedsession__user=user,
|
authenticatedsession__user=user,
|
||||||
).exclude(session_key=self.request.session.session_key).delete()
|
).exclude(session_key=self.request.session.session_key).delete()
|
||||||
|
if remember is None:
|
||||||
|
return self.set_known_device_cookie(user)
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|||||||
@@ -8,17 +8,18 @@ from django.urls import reverse
|
|||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
|
||||||
from authentik.core.models import AuthenticatedSession, Session
|
from authentik.core.models import AuthenticatedSession, Session
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_flow, create_test_user
|
||||||
from authentik.flows.markers import StageMarker
|
from authentik.flows.markers import StageMarker
|
||||||
from authentik.flows.models import FlowDesignation, FlowStageBinding
|
from authentik.flows.models import FlowDesignation, FlowStageBinding
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||||
from authentik.flows.tests import FlowTestCase
|
from authentik.flows.tests import FlowTestCase
|
||||||
from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK
|
from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK
|
||||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.root.middleware import ClientIPMiddleware
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.user_login.middleware import (
|
from authentik.stages.user_login.middleware import (
|
||||||
|
SESSION_KEY_BINDING_NET,
|
||||||
BoundSessionMiddleware,
|
BoundSessionMiddleware,
|
||||||
SessionBindingBroken,
|
SessionBindingBroken,
|
||||||
logout_extra,
|
logout_extra,
|
||||||
@@ -31,7 +32,7 @@ class TestUserLoginStage(FlowTestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.user = create_test_admin_user()
|
self.user = create_test_user()
|
||||||
|
|
||||||
self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
|
self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
|
||||||
self.stage = UserLoginStage.objects.create(name="login")
|
self.stage = UserLoginStage.objects.create(name="login")
|
||||||
@@ -247,3 +248,21 @@ class TestUserLoginStage(FlowTestCase):
|
|||||||
request.session = self.client.session
|
request.session = self.client.session
|
||||||
request.user = self.user
|
request.user = self.user
|
||||||
logout_extra(request, cm.exception)
|
logout_extra(request, cm.exception)
|
||||||
|
|
||||||
|
def test_session_binding_broken(self):
|
||||||
|
"""Test session binding"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
session = self.client.session
|
||||||
|
session[Session.Keys.LAST_IP] = "192.0.2.1"
|
||||||
|
session[SESSION_KEY_BINDING_NET] = NetworkBinding.BIND_ASN_NETWORK_IP
|
||||||
|
session.save()
|
||||||
|
|
||||||
|
res = self.client.get(reverse("authentik_api:user-me"))
|
||||||
|
self.assertEqual(res.status_code, 302)
|
||||||
|
self.assertEqual(
|
||||||
|
res.url,
|
||||||
|
reverse(
|
||||||
|
"authentik_flows:default-authentication",
|
||||||
|
)
|
||||||
|
+ f"?{NEXT_ARG_NAME}={reverse("authentik_api:user-me")}",
|
||||||
|
)
|
||||||
|
|||||||
@@ -4754,6 +4754,7 @@
|
|||||||
"add_token",
|
"add_token",
|
||||||
"change_token",
|
"change_token",
|
||||||
"delete_token",
|
"delete_token",
|
||||||
|
"set_token_key",
|
||||||
"view_token",
|
"view_token",
|
||||||
"view_token_key"
|
"view_token_key"
|
||||||
]
|
]
|
||||||
@@ -4890,6 +4891,7 @@
|
|||||||
"authentik_core.preview_user",
|
"authentik_core.preview_user",
|
||||||
"authentik_core.remove_user_from_group",
|
"authentik_core.remove_user_from_group",
|
||||||
"authentik_core.reset_user_password",
|
"authentik_core.reset_user_password",
|
||||||
|
"authentik_core.set_token_key",
|
||||||
"authentik_core.unassign_user_permissions",
|
"authentik_core.unassign_user_permissions",
|
||||||
"authentik_core.view_application",
|
"authentik_core.view_application",
|
||||||
"authentik_core.view_applicationentitlement",
|
"authentik_core.view_applicationentitlement",
|
||||||
@@ -9536,6 +9538,7 @@
|
|||||||
"authentik_core.preview_user",
|
"authentik_core.preview_user",
|
||||||
"authentik_core.remove_user_from_group",
|
"authentik_core.remove_user_from_group",
|
||||||
"authentik_core.reset_user_password",
|
"authentik_core.reset_user_password",
|
||||||
|
"authentik_core.set_token_key",
|
||||||
"authentik_core.unassign_user_permissions",
|
"authentik_core.unassign_user_permissions",
|
||||||
"authentik_core.view_application",
|
"authentik_core.view_application",
|
||||||
"authentik_core.view_applicationentitlement",
|
"authentik_core.view_applicationentitlement",
|
||||||
@@ -10958,6 +10961,7 @@
|
|||||||
"enum": [
|
"enum": [
|
||||||
"apple",
|
"apple",
|
||||||
"openidconnect",
|
"openidconnect",
|
||||||
|
"entraid",
|
||||||
"azuread",
|
"azuread",
|
||||||
"discord",
|
"discord",
|
||||||
"facebook",
|
"facebook",
|
||||||
@@ -15546,6 +15550,12 @@
|
|||||||
],
|
],
|
||||||
"title": "Geoip binding",
|
"title": "Geoip binding",
|
||||||
"description": "Bind sessions created by this stage to the configured GeoIP location"
|
"description": "Bind sessions created by this stage to the configured GeoIP location"
|
||||||
|
},
|
||||||
|
"remember_device": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
"title": "Remember device",
|
||||||
|
"description": "When set to a non-zero value, authentik will save a cookie with a longer expiry,to remember the device the user is logging in from. (Format: hours=-1;minutes=-2;seconds=-3)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": []
|
"required": []
|
||||||
|
|||||||
17
go.mod
17
go.mod
@@ -6,18 +6,18 @@ require (
|
|||||||
beryju.io/ldap v0.1.0
|
beryju.io/ldap v0.1.0
|
||||||
github.com/avast/retry-go/v4 v4.6.1
|
github.com/avast/retry-go/v4 v4.6.1
|
||||||
github.com/coreos/go-oidc/v3 v3.14.1
|
github.com/coreos/go-oidc/v3 v3.14.1
|
||||||
github.com/getsentry/sentry-go v0.34.0
|
github.com/getsentry/sentry-go v0.34.1
|
||||||
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
|
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
|
||||||
github.com/go-ldap/ldap/v3 v3.4.11
|
github.com/go-ldap/ldap/v3 v3.4.11
|
||||||
github.com/go-openapi/runtime v0.28.0
|
github.com/go-openapi/runtime v0.28.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
github.com/golang-jwt/jwt/v5 v5.2.3
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/handlers v1.5.2
|
github.com/gorilla/handlers v1.5.2
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
github.com/gorilla/securecookie v1.1.2
|
github.com/gorilla/securecookie v1.1.2
|
||||||
github.com/gorilla/sessions v1.4.0
|
github.com/gorilla/sessions v1.4.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/grafana/pyroscope-go v1.2.2
|
github.com/grafana/pyroscope-go v1.2.3
|
||||||
github.com/jellydator/ttlcache/v3 v3.4.0
|
github.com/jellydator/ttlcache/v3 v3.4.0
|
||||||
github.com/mitchellh/mapstructure v1.5.0
|
github.com/mitchellh/mapstructure v1.5.0
|
||||||
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
|
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
|
||||||
@@ -29,10 +29,10 @@ require (
|
|||||||
github.com/spf13/cobra v1.9.1
|
github.com/spf13/cobra v1.9.1
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/wwt/guac v1.3.2
|
github.com/wwt/guac v1.3.2
|
||||||
goauthentik.io/api/v3 v3.2025063.1
|
goauthentik.io/api/v3 v3.2025063.5
|
||||||
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
|
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
|
||||||
golang.org/x/oauth2 v0.30.0
|
golang.org/x/oauth2 v0.30.0
|
||||||
golang.org/x/sync v0.15.0
|
golang.org/x/sync v0.16.0
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
|
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
|
||||||
)
|
)
|
||||||
@@ -77,9 +77,10 @@ require (
|
|||||||
go.opentelemetry.io/otel v1.24.0 // indirect
|
go.opentelemetry.io/otel v1.24.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.24.0 // indirect
|
go.opentelemetry.io/otel/trace v1.24.0 // indirect
|
||||||
golang.org/x/crypto v0.36.0 // indirect
|
golang.org/x/crypto v0.38.0 // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/net v0.40.0 // indirect
|
||||||
golang.org/x/text v0.24.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
|
golang.org/x/text v0.25.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.5 // indirect
|
google.golang.org/protobuf v1.36.5 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
36
go.sum
36
go.sum
@@ -71,8 +71,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
|
|||||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
|
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
|
||||||
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4=
|
github.com/getsentry/sentry-go v0.34.1 h1:HSjc1C/OsnZttohEPrrqKH42Iud0HuLCXpv8cU1pWcw=
|
||||||
github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
|
github.com/getsentry/sentry-go v0.34.1/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||||
@@ -115,8 +115,8 @@ github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+Gr
|
|||||||
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
|
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
|
||||||
github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58=
|
github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58=
|
||||||
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
github.com/golang-jwt/jwt/v5 v5.2.3 h1:kkGXqQOBSDDWRhWNXTFpqGSCMyh/PLnqUvMGJPDJDs0=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.3/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
@@ -180,8 +180,8 @@ github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2e
|
|||||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
github.com/grafana/pyroscope-go v1.2.3 h1:Rp8mjqqGqmRDvV6XYmuedUAv7wVnQJK/M1pBt6uNwxU=
|
||||||
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
github.com/grafana/pyroscope-go v1.2.3/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||||
@@ -298,16 +298,16 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
|
|||||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
goauthentik.io/api/v3 v3.2025063.1 h1:zvKhZTESgMY/SNiLuTs7G0YleBnev1v7+S9Xd6PZ9bc=
|
goauthentik.io/api/v3 v3.2025063.5 h1:j5el9/qI/72Q5x5QAiMzgQTswMj2TK3h74OaBcFEtkI=
|
||||||
goauthentik.io/api/v3 v3.2025063.1/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
goauthentik.io/api/v3 v3.2025063.5/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
@@ -366,8 +366,8 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
|
||||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
@@ -384,8 +384,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -413,15 +413,15 @@ golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func (fe *FlowExecutor) solveChallenge_AuthenticatorValidate(challenge *api.Chal
|
|||||||
var deviceChallenge *api.DeviceChallenge
|
var deviceChallenge *api.DeviceChallenge
|
||||||
inner := api.NewAuthenticatorValidationChallengeResponseRequest()
|
inner := api.NewAuthenticatorValidationChallengeResponseRequest()
|
||||||
for _, devCh := range challenge.AuthenticatorValidationChallenge.DeviceChallenges {
|
for _, devCh := range challenge.AuthenticatorValidationChallenge.DeviceChallenges {
|
||||||
if devCh.DeviceClass == string(api.DEVICECLASSESENUM_DUO) {
|
if devCh.DeviceClass == api.DEVICECLASSESENUM_DUO {
|
||||||
deviceChallenge = &devCh
|
deviceChallenge = &devCh
|
||||||
devId, err := strconv.ParseInt(deviceChallenge.DeviceUid, 10, 32)
|
devId, err := strconv.ParseInt(deviceChallenge.DeviceUid, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,8 +38,8 @@ func (fe *FlowExecutor) solveChallenge_AuthenticatorValidate(challenge *api.Chal
|
|||||||
inner.SelectedChallenge = (*api.DeviceChallengeRequest)(deviceChallenge)
|
inner.SelectedChallenge = (*api.DeviceChallengeRequest)(deviceChallenge)
|
||||||
inner.Duo = &devId32
|
inner.Duo = &devId32
|
||||||
}
|
}
|
||||||
if devCh.DeviceClass == string(api.DEVICECLASSESENUM_STATIC) ||
|
if devCh.DeviceClass == api.DEVICECLASSESENUM_STATIC ||
|
||||||
devCh.DeviceClass == string(api.DEVICECLASSESENUM_TOTP) {
|
devCh.DeviceClass == api.DEVICECLASSESENUM_TOTP {
|
||||||
// Only use code-based devices if we have a code in the entered password,
|
// Only use code-based devices if we have a code in the entered password,
|
||||||
// and we haven't selected a push device yet
|
// and we haven't selected a push device yet
|
||||||
if deviceChallenge == nil && fe.getAnswer(StageAuthenticatorValidate) != "" {
|
if deviceChallenge == nil && fe.getAnswer(StageAuthenticatorValidate) != "" {
|
||||||
|
|||||||
@@ -100,6 +100,9 @@ elif [[ "$1" == "healthcheck" ]]; then
|
|||||||
elif [[ "$1" == "dump_config" ]]; then
|
elif [[ "$1" == "dump_config" ]]; then
|
||||||
shift
|
shift
|
||||||
exec python -m authentik.lib.config $@
|
exec python -m authentik.lib.config $@
|
||||||
|
elif [[ "$1" == "support" ]]; then
|
||||||
|
wait_for_db
|
||||||
|
exec python -m lifecycle.support
|
||||||
elif [[ "$1" == "debug" ]]; then
|
elif [[ "$1" == "debug" ]]; then
|
||||||
exec sleep infinity
|
exec sleep infinity
|
||||||
else
|
else
|
||||||
|
|||||||
9
lifecycle/aws/package-lock.json
generated
9
lifecycle/aws/package-lock.json
generated
@@ -9,7 +9,7 @@
|
|||||||
"version": "0.0.0",
|
"version": "0.0.0",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"aws-cdk": "^2.1020.1",
|
"aws-cdk": "^2.1021.0",
|
||||||
"cross-env": "^7.0.3"
|
"cross-env": "^7.0.3"
|
||||||
},
|
},
|
||||||
"engines": {
|
"engines": {
|
||||||
@@ -17,10 +17,11 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/aws-cdk": {
|
"node_modules/aws-cdk": {
|
||||||
"version": "2.1020.1",
|
"version": "2.1021.0",
|
||||||
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1020.1.tgz",
|
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1021.0.tgz",
|
||||||
"integrity": "sha512-4UG9qzf6ZSDjINubcukPZChVj6PvDJAHiURAw0jYSkUhObPkX7Zo9uNUIlXzrM+hpB2N2jwRKY9b3sN+KDbtAQ==",
|
"integrity": "sha512-kE557b4N9UFWax+7km3R6D56o4tGhpzOks/lRDugaoC8su3mocLCXJhb954b/IRl0ipnbZnY/Sftq+RQ/sxivg==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
|
"license": "Apache-2.0",
|
||||||
"bin": {
|
"bin": {
|
||||||
"cdk": "bin/cdk"
|
"cdk": "bin/cdk"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
"node": ">=20"
|
"node": ">=20"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"aws-cdk": "^2.1020.1",
|
"aws-cdk": "^2.1021.0",
|
||||||
"cross-env": "^7.0.3"
|
"cross-env": "^7.0.3"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user