mirror of
https://github.com/goauthentik/authentik
synced 2026-04-26 01:25:02 +02:00
Compare commits
187 Commits
docs-event
...
admin/vers
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fbce9611d2 | ||
|
|
e6643a69cd | ||
|
|
0fdeaee559 | ||
|
|
f9fd1bbf09 | ||
|
|
3ba3b11a76 | ||
|
|
19e558e916 | ||
|
|
e15fadfedd | ||
|
|
52854e61c7 | ||
|
|
53aa0113ca | ||
|
|
9f71face62 | ||
|
|
2fadefb5b4 | ||
|
|
23e92bceae | ||
|
|
1ff2eea20a | ||
|
|
abcd2179bf | ||
|
|
6a4b5850a0 | ||
|
|
821c8c36cd | ||
|
|
8838efe3c0 | ||
|
|
433a4a3037 | ||
|
|
2d69a67e9d | ||
|
|
1294cc64e8 | ||
|
|
910326a05a | ||
|
|
9257b3e570 | ||
|
|
cdd18a7e5a | ||
|
|
88bea46648 | ||
|
|
295090a80b | ||
|
|
bff607a5c3 | ||
|
|
bfb2fb4fcf | ||
|
|
93015b0fce | ||
|
|
9b6c0d3f1a | ||
|
|
66e95ddb20 | ||
|
|
c5d8524a7d | ||
|
|
a4761064c2 | ||
|
|
b0de8bf71f | ||
|
|
32100fd3b9 | ||
|
|
4815e97162 | ||
|
|
dee99c38bb | ||
|
|
a024056b62 | ||
|
|
a8dc21b707 | ||
|
|
7ccda743df | ||
|
|
0c795dd077 | ||
|
|
5df9ed3582 | ||
|
|
a47b4934a5 | ||
|
|
338a6e74f4 | ||
|
|
8897af1048 | ||
|
|
56ec3f7def | ||
|
|
53fd893d91 | ||
|
|
f7d9a8cafe | ||
|
|
f97c1071f3 | ||
|
|
4da1115a7c | ||
|
|
63b1ccd4c3 | ||
|
|
63aa7f4684 | ||
|
|
d997930b60 | ||
|
|
a088a62981 | ||
|
|
118e05f256 | ||
|
|
b30500094f | ||
|
|
21af51ba59 | ||
|
|
87da0497e0 | ||
|
|
87317d6e7f | ||
|
|
071305da18 | ||
|
|
1dc8ed5e55 | ||
|
|
dc8dee985f | ||
|
|
2b20b06baa | ||
|
|
6cab1f85e4 | ||
|
|
f836c38b18 | ||
|
|
07e373e505 | ||
|
|
e361d38978 | ||
|
|
3ba1691db6 | ||
|
|
7c2987ea32 | ||
|
|
4ca88caf07 | ||
|
|
6c939341b0 | ||
|
|
4142584788 | ||
|
|
f6fbafd280 | ||
|
|
7c9555bee8 | ||
|
|
82cd64dfe7 | ||
|
|
28f0b48e33 | ||
|
|
38c02dc490 | ||
|
|
79505969db | ||
|
|
9870888456 | ||
|
|
5c06e1920e | ||
|
|
1506ad8aa4 | ||
|
|
21b6204c90 | ||
|
|
05621735cb | ||
|
|
f9ffd35ab8 | ||
|
|
c3ded3a835 | ||
|
|
7629c22050 | ||
|
|
29a66410fd | ||
|
|
f147d40c5f | ||
|
|
15b556c1be | ||
|
|
522e8a26a2 | ||
|
|
403d762f65 | ||
|
|
cbc65ffd74 | ||
|
|
9a9bafdfb4 | ||
|
|
198d2a1a8a | ||
|
|
239edace16 | ||
|
|
370d5ff0c0 | ||
|
|
635b09621b | ||
|
|
4335498ac5 | ||
|
|
72af009de8 | ||
|
|
3a07d5d829 | ||
|
|
7122891f0f | ||
|
|
c32d6cc75e | ||
|
|
eaf6be74f3 | ||
|
|
c35650afbd | ||
|
|
a1f9ff8b7d | ||
|
|
962f7513ba | ||
|
|
0ec5ea69ef | ||
|
|
d8a3098329 | ||
|
|
80ad97b28d | ||
|
|
cd98767dbc | ||
|
|
30f09e8c45 | ||
|
|
154bcb58a6 | ||
|
|
597945edf1 | ||
|
|
38d6e39fe0 | ||
|
|
1a6065f72a | ||
|
|
d07e0f015d | ||
|
|
7f931917fa | ||
|
|
d7fb684292 | ||
|
|
bd0fa7be98 | ||
|
|
2907808a7e | ||
|
|
c53016b2e5 | ||
|
|
4479587baa | ||
|
|
08d24a1871 | ||
|
|
42ea8bb3ed | ||
|
|
c9a07fa18d | ||
|
|
4130446cbc | ||
|
|
b4aecbd782 | ||
|
|
981d2af109 | ||
|
|
db96e13813 | ||
|
|
3d39cc4974 | ||
|
|
d36ec31224 | ||
|
|
bb7a2002f2 | ||
|
|
8fff802936 | ||
|
|
0f3fb9f93c | ||
|
|
1e76d1f883 | ||
|
|
140d9fe95c | ||
|
|
67eacbe860 | ||
|
|
435b815617 | ||
|
|
0459feeb8a | ||
|
|
4e6e730014 | ||
|
|
1231fc8237 | ||
|
|
b7f320d7cc | ||
|
|
35073b03ac | ||
|
|
b3b8b8bb1c | ||
|
|
17ee42f98f | ||
|
|
e8f95a4b08 | ||
|
|
decc0c1ae1 | ||
|
|
716bfa9043 | ||
|
|
4d8feb15e3 | ||
|
|
d50eab08e8 | ||
|
|
09b2a2bd4f | ||
|
|
210d9cf31c | ||
|
|
a0291a1b32 | ||
|
|
790ae0c3d8 | ||
|
|
8fc744fb56 | ||
|
|
392011cac4 | ||
|
|
15316b6bae | ||
|
|
dccb1d01f0 | ||
|
|
e8cd762c6e | ||
|
|
12847d9a87 | ||
|
|
6c4cb06825 | ||
|
|
aa8e971477 | ||
|
|
6c02d5a316 | ||
|
|
2f3259bf13 | ||
|
|
8b7a538419 | ||
|
|
d0127d83c9 | ||
|
|
262ca4aea9 | ||
|
|
9923cb73a6 | ||
|
|
b58a8774d4 | ||
|
|
bf6a37a5dc | ||
|
|
1eda16cbd6 | ||
|
|
8c3397e5f2 | ||
|
|
974b4d5c82 | ||
|
|
00daba0d0c | ||
|
|
63d547194c | ||
|
|
ec171bd282 | ||
|
|
155fa433b3 | ||
|
|
7a88fd5b6b | ||
|
|
7d9fb85827 | ||
|
|
0021e5fa25 | ||
|
|
6919838c12 | ||
|
|
9841d976e1 | ||
|
|
87051cf588 | ||
|
|
dec7ac675c | ||
|
|
53e17ff877 | ||
|
|
4635a07edd | ||
|
|
f70b74fc03 | ||
|
|
2713a5ec88 |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -31,4 +31,4 @@ If changes to the frontend have been made
|
||||
If applicable
|
||||
|
||||
- [ ] The documentation has been updated
|
||||
- [ ] The documentation has been formatted (`make website`)
|
||||
- [ ] The documentation has been formatted (`make docs`)
|
||||
|
||||
4
.github/workflows/api-ts-publish.yml
vendored
4
.github/workflows/api-ts-publish.yml
vendored
@@ -27,8 +27,8 @@ jobs:
|
||||
- name: Publish package
|
||||
working-directory: gen-ts-api/
|
||||
run: |
|
||||
npm ci
|
||||
npm publish
|
||||
npm i
|
||||
npm publish --tag generated
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_PUBLISH_TOKEN }}
|
||||
- name: Upgrade /web
|
||||
|
||||
@@ -62,7 +62,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
go build -o /go/authentik ./cmd/server
|
||||
|
||||
# Stage 3: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.0 AS geoip
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.1 AS geoip
|
||||
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||
ENV GEOIPUPDATE_VERBOSE="1"
|
||||
@@ -75,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 4: Download uv
|
||||
FROM ghcr.io/astral-sh/uv:0.7.18 AS uv
|
||||
FROM ghcr.io/astral-sh/uv:0.7.21 AS uv
|
||||
# Stage 5: Base python image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
|
||||
|
||||
|
||||
43
Makefile
43
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: gen dev-reset all clean test web website
|
||||
.PHONY: gen dev-reset all clean test web docs
|
||||
|
||||
SHELL := /usr/bin/env bash
|
||||
.SHELLFLAGS += ${SHELLFLAGS} -e -o pipefail
|
||||
@@ -73,7 +73,7 @@ core-i18n-extract:
|
||||
--ignore website \
|
||||
-l en
|
||||
|
||||
install: web-install website-install core-install ## Install all requires dependencies for `web`, `website` and `core`
|
||||
install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core`
|
||||
|
||||
dev-drop-db:
|
||||
dropdb -U ${pg_user} -h ${pg_host} ${pg_name}
|
||||
@@ -183,18 +183,23 @@ gen-dev-config: ## Generate a local development config file
|
||||
|
||||
gen: gen-build gen-client-ts
|
||||
|
||||
#########################
|
||||
## Node.js
|
||||
#########################
|
||||
|
||||
node-install: ## Install the necessary libraries to build Node.js packages
|
||||
npm ci
|
||||
npm ci --prefix web
|
||||
|
||||
#########################
|
||||
## Web
|
||||
#########################
|
||||
|
||||
web-build: web-install ## Build the Authentik UI
|
||||
web-build: node-install ## Build the Authentik UI
|
||||
cd web && npm run build
|
||||
|
||||
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
|
||||
|
||||
web-install: ## Install the necessary libraries to build the Authentik UI
|
||||
cd web && npm ci
|
||||
|
||||
web-test: ## Run tests for the Authentik UI
|
||||
cd web && npm run test
|
||||
|
||||
@@ -221,22 +226,28 @@ web-i18n-extract:
|
||||
cd web && npm run extract-locales
|
||||
|
||||
#########################
|
||||
## Website
|
||||
## Docs
|
||||
#########################
|
||||
|
||||
website: website-lint-fix website-build ## Automatically fix formatting issues in the Authentik website/docs source code, lint the code, and compile it
|
||||
docs: docs-lint-fix docs-build ## Automatically fix formatting issues in the Authentik docs source code, lint the code, and compile it
|
||||
|
||||
website-install:
|
||||
cd website && npm ci
|
||||
docs-install:
|
||||
npm ci --prefix website
|
||||
|
||||
website-lint-fix: lint-codespell
|
||||
cd website && npm run prettier
|
||||
docs-lint-fix: lint-codespell
|
||||
npm run prettier --prefix website
|
||||
|
||||
website-build:
|
||||
cd website && npm run build
|
||||
docs-build:
|
||||
npm run build --prefix website
|
||||
|
||||
website-watch: ## Build and watch the documentation website, updating automatically
|
||||
cd website && npm run watch
|
||||
docs-watch: ## Build and watch the topics documentation
|
||||
npm run start --prefix website
|
||||
|
||||
docs-integrations-build:
|
||||
npm run build --prefix website -w integrations
|
||||
|
||||
docs-integrations-watch: ## Build and watch the Integrations documentation
|
||||
npm run start --prefix website -w integrations
|
||||
|
||||
#########################
|
||||
## Docker
|
||||
|
||||
@@ -42,7 +42,11 @@ class Exporter:
|
||||
if model in self.excluded_models:
|
||||
continue
|
||||
for obj in self.get_model_instances(model):
|
||||
yield BlueprintEntry.from_model(obj)
|
||||
yield BlueprintEntry.from_model(self.alter_model(obj))
|
||||
|
||||
def alter_model(self, model: Model):
|
||||
"""Hook to modify the model before exporting"""
|
||||
return model
|
||||
|
||||
def get_model_instances(self, model: type[Model]) -> QuerySet:
|
||||
"""Return a queryset for `model`. Can be used to filter some
|
||||
|
||||
@@ -52,6 +52,27 @@ class TestBrands(APITestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_brand_subdomain_same_suffix(self):
|
||||
"""Test Current brand API"""
|
||||
Brand.objects.all().delete()
|
||||
Brand.objects.create(domain="bar.baz", branding_title="custom")
|
||||
Brand.objects.create(domain="foo.bar.baz", branding_title="custom")
|
||||
self.assertJSONEqual(
|
||||
self.client.get(
|
||||
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "custom",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "foo.bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
},
|
||||
)
|
||||
|
||||
def test_fallback(self):
|
||||
"""Test fallback brand"""
|
||||
Brand.objects.all().delete()
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from django.db.models import F, Q
|
||||
from django.db.models import Value as V
|
||||
from django.db.models.functions import Length
|
||||
from django.http.request import HttpRequest
|
||||
from django.utils.html import _json_script_escapes
|
||||
from django.utils.safestring import mark_safe
|
||||
@@ -20,9 +21,9 @@ DEFAULT_BRAND = Brand(domain="fallback")
|
||||
def get_brand_for_request(request: HttpRequest) -> Brand:
|
||||
"""Get brand object for current request"""
|
||||
db_brands = (
|
||||
Brand.objects.annotate(host_domain=V(request.get_host()))
|
||||
Brand.objects.annotate(host_domain=V(request.get_host()), match_length=Length("domain"))
|
||||
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
|
||||
.order_by("default")
|
||||
.order_by("-match_length", "default")
|
||||
)
|
||||
brands = list(db_brands.all())
|
||||
if len(brands) < 1:
|
||||
|
||||
@@ -149,10 +149,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
return applications
|
||||
|
||||
def _filter_applications_with_launch_url(
|
||||
self, pagined_apps: Iterator[Application]
|
||||
self, paginated_apps: Iterator[Application]
|
||||
) -> list[Application]:
|
||||
applications = []
|
||||
for app in pagined_apps:
|
||||
for app in paginated_apps:
|
||||
if app.get_launch_url():
|
||||
applications.append(app)
|
||||
return applications
|
||||
|
||||
@@ -11,7 +11,6 @@ from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.types import PolicyRequest
|
||||
|
||||
PROPERTY_MAPPING_TIME = Histogram(
|
||||
@@ -69,12 +68,11 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
||||
# For dry-run requests we don't save exceptions
|
||||
if self.dry_run:
|
||||
return
|
||||
error_string = exception_to_string(exc)
|
||||
event = Event.new(
|
||||
EventAction.PROPERTY_MAPPING_EXCEPTION,
|
||||
expression=expression_source,
|
||||
message=error_string,
|
||||
)
|
||||
message="Failed to execute property mapping",
|
||||
).with_exception(exc)
|
||||
if "request" in self._context:
|
||||
req: PolicyRequest = self._context["request"]
|
||||
if req.http_request:
|
||||
|
||||
24
authentik/core/migrations/0049_alter_token_options.py
Normal file
24
authentik/core/migrations/0049_alter_token_options.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Generated by Django 5.1.11 on 2025-07-03 13:08
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0048_delete_oldauthenticatedsession_content_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="token",
|
||||
options={
|
||||
"permissions": [
|
||||
("view_token_key", "View token's key"),
|
||||
("set_token_key", "Set a token's key"),
|
||||
],
|
||||
"verbose_name": "Token",
|
||||
"verbose_name_plural": "Tokens",
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -953,7 +953,10 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
|
||||
models.Index(fields=["identifier"]),
|
||||
models.Index(fields=["key"]),
|
||||
]
|
||||
permissions = [("view_token_key", _("View token's key"))]
|
||||
permissions = [
|
||||
("view_token_key", _("View token's key")),
|
||||
("set_token_key", _("Set a token's key")),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
description = f"{self.identifier}"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from django.http import HttpResponse
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
@@ -10,16 +11,21 @@ from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.renderers import BaseRenderer
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.validators import UniqueValidator
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import User, UserTypes
|
||||
from authentik.enterprise.bundle import generate_support_bundle
|
||||
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
|
||||
from authentik.enterprise.models import License
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
from authentik.tenants.utils import get_unique_identifier
|
||||
|
||||
|
||||
@@ -53,6 +59,7 @@ class LicenseSerializer(ModelSerializer):
|
||||
"external_users",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"key": {"validators": [UniqueValidator(queryset=License.objects.all())]},
|
||||
"name": {"read_only": True},
|
||||
"expiry": {"read_only": True},
|
||||
"internal_users": {"read_only": True},
|
||||
@@ -145,3 +152,24 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
||||
)
|
||||
response.is_valid(raise_exception=True)
|
||||
return Response(response.data)
|
||||
|
||||
|
||||
class BinaryRenderer(BaseRenderer):
|
||||
media_type = "application/gzip"
|
||||
format = "bin"
|
||||
|
||||
|
||||
class SupportBundleView(APIView):
|
||||
"""Generate a support bundle."""
|
||||
|
||||
permission_classes = [HasPermission("authentik_rbac.view_system_info")]
|
||||
pagination_class = None
|
||||
filter_backends = []
|
||||
renderer_classes = [BinaryRenderer]
|
||||
|
||||
@extend_schema(responses=bytes, request=None)
|
||||
def post(self, request: Request) -> Response:
|
||||
"""Generate a support bundle."""
|
||||
response = HttpResponse(generate_support_bundle(), content_type=BinaryRenderer.media_type)
|
||||
response["Content-Disposition"] = 'attachment; filename="authentik_support.tgz"'
|
||||
return response
|
||||
|
||||
@@ -65,13 +65,17 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
data[field.name] = deepcopy(field_value)
|
||||
return cleanse_dict(data)
|
||||
|
||||
def diff(self, before: dict, after: dict) -> dict:
|
||||
def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
|
||||
"""Generate diff between dicts"""
|
||||
diff = {}
|
||||
for key, value in before.items():
|
||||
if update_fields and key not in update_fields:
|
||||
continue
|
||||
if after.get(key) != value:
|
||||
diff[key] = {"previous_value": value, "new_value": after.get(key)}
|
||||
for key, value in after.items():
|
||||
if update_fields and key not in update_fields:
|
||||
continue
|
||||
if key not in before and key not in diff and before.get(key) != value:
|
||||
diff[key] = {"previous_value": before.get(key), "new_value": value}
|
||||
return sanitize_item(diff)
|
||||
@@ -95,6 +99,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
instance: Model,
|
||||
created: bool,
|
||||
thread_kwargs: dict | None = None,
|
||||
update_fields: list[str] | None = None,
|
||||
**_,
|
||||
):
|
||||
if not self.enabled:
|
||||
@@ -108,7 +113,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
prev_state = {}
|
||||
# Get current state
|
||||
new_state = self.serialize_simple(instance)
|
||||
diff = self.diff(prev_state, new_state)
|
||||
diff = self.diff(prev_state, new_state, update_fields)
|
||||
thread_kwargs["diff"] = diff
|
||||
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.enterprise.audit.middleware import EnterpriseAuditMiddleware
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.generators import generate_id
|
||||
@@ -208,3 +209,23 @@ class TestEnterpriseAudit(APITestCase):
|
||||
diff,
|
||||
{"users": {"remove": [user.pk]}},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_diff_update_fields(self):
|
||||
"""Test update audit log"""
|
||||
self.client.force_login(self.user)
|
||||
diff = EnterpriseAuditMiddleware(None).diff(
|
||||
{
|
||||
"foo": "bar",
|
||||
"is_active": False,
|
||||
},
|
||||
{
|
||||
"foo": "baz",
|
||||
"is_active": True,
|
||||
},
|
||||
update_fields=["is_active"],
|
||||
)
|
||||
self.assertEqual(diff, {"is_active": {"new_value": True, "previous_value": False}})
|
||||
|
||||
53
authentik/enterprise/bundle.py
Normal file
53
authentik/enterprise/bundle.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import re
|
||||
from io import BytesIO
|
||||
from tarfile import TarInfo, open
|
||||
|
||||
from django.db.models import Model
|
||||
from django.db.models.fields import CharField, SlugField, TextField
|
||||
from django.db.models.fields.json import JSONField
|
||||
|
||||
from authentik.blueprints.v1.exporter import Exporter
|
||||
from authentik.core.models import User
|
||||
from lifecycle.support import encrypt, generate
|
||||
|
||||
SENSITIVE_VALUE_PLACEHOLDER = "<REDACTED>"
|
||||
|
||||
|
||||
class SupportExporter(Exporter):
|
||||
"""Blueprint exporter which censors sensitive model attributes"""
|
||||
|
||||
sensitive_fields = re.compile(
|
||||
# Partially taken from Django's SafeExceptionReporterFilter
|
||||
"API|AUTH|TOKEN|KEY|SECRET|PASS|SIGNATURE|CREDENTIALS",
|
||||
re.I,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.excluded_models.append(User)
|
||||
|
||||
def alter_model(self, model: Model):
|
||||
for field in model._meta.fields:
|
||||
if not self.sensitive_fields.search(field.name):
|
||||
continue
|
||||
if isinstance(field, TextField | CharField | SlugField):
|
||||
setattr(model, field.name, SENSITIVE_VALUE_PLACEHOLDER)
|
||||
elif isinstance(field, JSONField):
|
||||
setattr(model, field.name, {})
|
||||
return model
|
||||
|
||||
|
||||
def generate_support_bundle():
|
||||
fh = BytesIO()
|
||||
exporter = SupportExporter()
|
||||
files = {
|
||||
"authentik/support.jwe": encrypt(generate()),
|
||||
"authentik/blueprint.yaml": exporter.export_to_string(),
|
||||
}
|
||||
with open(fileobj=fh, mode="w:gz") as tar:
|
||||
for path, file in files.items():
|
||||
info = TarInfo(path)
|
||||
info.size = len(file)
|
||||
tar.addfile(info, BytesIO(file.encode()))
|
||||
final_data = fh.getvalue()
|
||||
return final_data
|
||||
@@ -6,7 +6,7 @@ from djangoql.ast import Name
|
||||
from djangoql.exceptions import DjangoQLError
|
||||
from djangoql.queryset import apply_search
|
||||
from djangoql.schema import DjangoQLSchema
|
||||
from rest_framework.filters import BaseFilterBackend, SearchFilter
|
||||
from rest_framework.filters import SearchFilter
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@@ -39,7 +39,7 @@ class BaseSchema(DjangoQLSchema):
|
||||
return super().resolve_name(name)
|
||||
|
||||
|
||||
class QLSearch(BaseFilterBackend):
|
||||
class QLSearch(SearchFilter):
|
||||
"""rest_framework search filter which uses DjangoQL"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -16,7 +16,7 @@ from authentik.stages.authenticator.models import Device
|
||||
|
||||
|
||||
class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
"""Setup Google Chrome Device-trust connection"""
|
||||
"""Setup Google Chrome Device Trust connection"""
|
||||
|
||||
credentials = models.JSONField()
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
from authentik.stages.user_login.stage import PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE
|
||||
|
||||
# Header we get from chrome that initiates verified access
|
||||
HEADER_DEVICE_TRUST = "X-Device-Trust"
|
||||
@@ -27,6 +28,8 @@ HEADER_ACCESS_CHALLENGE_RESPONSE = "X-Verified-Access-Challenge-Response"
|
||||
# Header value for x-device-trust that initiates the flow
|
||||
DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess"
|
||||
|
||||
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS = "endpoints"
|
||||
|
||||
|
||||
@method_decorator(xframe_options_sameorigin, name="dispatch")
|
||||
class GoogleChromeDeviceTrustConnector(View):
|
||||
@@ -81,7 +84,14 @@ class GoogleChromeDeviceTrustConnector(View):
|
||||
)
|
||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint")
|
||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault("endpoints", [])
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS]["endpoints"].append(response)
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS, []
|
||||
)
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS][PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS].append(
|
||||
response
|
||||
)
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, True
|
||||
)
|
||||
request.session[SESSION_KEY_PLAN] = flow_plan
|
||||
return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html")
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""API URLs"""
|
||||
|
||||
from authentik.enterprise.api import LicenseViewSet
|
||||
from django.urls import path
|
||||
|
||||
from authentik.enterprise.api import LicenseViewSet, SupportBundleView
|
||||
|
||||
api_urlpatterns = [
|
||||
("enterprise/license", LicenseViewSet),
|
||||
path(
|
||||
"enterprise/support_bundle/", SupportBundleView.as_view(), name="enterprise_support_bundle"
|
||||
),
|
||||
]
|
||||
|
||||
@@ -20,7 +20,7 @@ from authentik.core.models import Group, User
|
||||
from authentik.events.models import Event, EventAction, Notification
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.lib.sentry import should_ignore_exception
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.stages.authenticator_static.models import StaticToken
|
||||
|
||||
IGNORED_MODELS = tuple(
|
||||
@@ -170,14 +170,16 @@ class AuditMiddleware:
|
||||
thread = EventNewThread(
|
||||
EventAction.SUSPICIOUS_REQUEST,
|
||||
request,
|
||||
message=exception_to_string(exception),
|
||||
message=str(exception),
|
||||
exception=exception_to_dict(exception),
|
||||
)
|
||||
thread.run()
|
||||
elif not should_ignore_exception(exception):
|
||||
thread = EventNewThread(
|
||||
EventAction.SYSTEM_EXCEPTION,
|
||||
request,
|
||||
message=exception_to_string(exception),
|
||||
message=str(exception),
|
||||
exception=exception_to_dict(exception),
|
||||
)
|
||||
thread.run()
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from authentik.events.utils import (
|
||||
)
|
||||
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
@@ -163,6 +164,12 @@ class Event(SerializerModel, ExpiringModel):
|
||||
event = Event(action=action, app=app, context=cleaned_kwargs)
|
||||
return event
|
||||
|
||||
def with_exception(self, exc: Exception) -> "Event":
|
||||
"""Add data from 'exc' to the event in a database-saveable format"""
|
||||
self.context.setdefault("message", str(exc))
|
||||
self.context["exception"] = exception_to_dict(exc)
|
||||
return self
|
||||
|
||||
def set_user(self, user: User) -> "Event":
|
||||
"""Set `.user` based on user, ensuring the correct attributes are copied.
|
||||
This should only be used when self.from_http is *not* used."""
|
||||
|
||||
@@ -127,8 +127,8 @@ class SystemTask(TenantTask):
|
||||
)
|
||||
Event.new(
|
||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
|
||||
).save()
|
||||
message=f"Task {self.__name__} encountered an error",
|
||||
).with_exception(exc).save()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -62,6 +62,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||
policy_engine.mode = PolicyEngineMode.MODE_ANY
|
||||
policy_engine.empty_result = False
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.request.obj = event
|
||||
policy_engine.request.context["event"] = event
|
||||
policy_engine.build()
|
||||
result = policy_engine.result
|
||||
|
||||
@@ -56,7 +56,6 @@ from authentik.flows.planner import (
|
||||
)
|
||||
from authentik.flows.stage import AccessDeniedStage, StageView
|
||||
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.reflection import all_subclasses, class_to_path
|
||||
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@@ -239,8 +238,8 @@ class FlowExecutorView(APIView):
|
||||
capture_exception(exc)
|
||||
Event.new(
|
||||
action=EventAction.SYSTEM_EXCEPTION,
|
||||
message=exception_to_string(exc),
|
||||
).from_http(self.request)
|
||||
message="System exception during flow execution.",
|
||||
).with_exception(exc).from_http(self.request)
|
||||
challenge = FlowErrorChallenge(self.request, exc)
|
||||
challenge.is_valid(raise_exception=True)
|
||||
return to_stage_response(self.request, HttpChallengeResponse(challenge))
|
||||
|
||||
@@ -14,7 +14,6 @@ from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.exceptions import ControlFlowException
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models import Model
|
||||
@@ -106,9 +105,9 @@ class BaseOutgoingSyncClient[
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
message="Failed to evaluate property-mapping",
|
||||
mapping=exc.mapping,
|
||||
).save()
|
||||
).with_exception(exc).save()
|
||||
raise StopSync(exc, obj, exc.mapping) from exc
|
||||
if not raw_final_object:
|
||||
raise StopSync(ValueError("No mappings configured"), obj)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from traceback import extract_tb
|
||||
|
||||
from structlog.tracebacks import ExceptionDictTransformer
|
||||
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
TRACEBACK_HEADER = "Traceback (most recent call last):"
|
||||
@@ -17,3 +19,8 @@ def exception_to_string(exc: Exception) -> str:
|
||||
f"{class_to_path(exc.__class__)}: {str(exc)}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def exception_to_dict(exc: Exception) -> dict:
|
||||
"""Format exception as a dictionary"""
|
||||
return ExceptionDictTransformer()((type(exc), exc, exc.__traceback__))
|
||||
|
||||
@@ -35,7 +35,6 @@ from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.models import InheritanceForeignKey, SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.outposts.controllers.k8s.utils import get_namespace
|
||||
|
||||
OUR_VERSION = parse(__version__)
|
||||
@@ -326,9 +325,8 @@ class Outpost(SerializerModel, ManagedModel):
|
||||
"While setting the permissions for the service-account, a "
|
||||
"permission was not found: Check "
|
||||
"https://goauthentik.io/docs/troubleshooting/missing_permission"
|
||||
)
|
||||
+ exception_to_string(exc),
|
||||
).set_user(user).save()
|
||||
),
|
||||
).with_exception(exc).set_user(user).save()
|
||||
else:
|
||||
app_label, perm = model_or_perm.split(".")
|
||||
permission = Permission.objects.filter(
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""authentik policy engine"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Iterable
|
||||
from multiprocessing import Pipe, current_process
|
||||
from multiprocessing.connection import Connection
|
||||
from time import perf_counter
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Count, Q, QuerySet
|
||||
from django.http import HttpRequest
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
@@ -67,14 +67,11 @@ class PolicyEngine:
|
||||
self.__processes: list[PolicyProcessInfo] = []
|
||||
self.use_cache = True
|
||||
self.__expected_result_count = 0
|
||||
self.__static_result: PolicyResult | None = None
|
||||
|
||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
||||
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
|
||||
"""Make sure all Policies are their respective classes"""
|
||||
return (
|
||||
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
|
||||
.order_by("order")
|
||||
.iterator()
|
||||
)
|
||||
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
|
||||
|
||||
def _check_policy_type(self, binding: PolicyBinding):
|
||||
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
|
||||
@@ -84,30 +81,66 @@ class PolicyEngine:
|
||||
def _check_cache(self, binding: PolicyBinding):
|
||||
if not self.use_cache:
|
||||
return False
|
||||
before = perf_counter()
|
||||
key = cache_key(binding, self.request)
|
||||
cached_policy = cache.get(key, None)
|
||||
duration = max(perf_counter() - before, 0)
|
||||
if not cached_policy:
|
||||
return False
|
||||
self.logger.debug(
|
||||
"P_ENG: Taking result from cache",
|
||||
binding=binding,
|
||||
cache_key=key,
|
||||
request=self.request,
|
||||
)
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
# It's a bit silly to time this, but
|
||||
with HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=binding.order,
|
||||
binding_target_type=binding.target_type,
|
||||
binding_target_name=binding.target_name,
|
||||
object_pk=str(self.request.obj.pk),
|
||||
object_type=class_to_path(self.request.obj.__class__),
|
||||
mode="cache_retrieve",
|
||||
).observe(duration)
|
||||
# It's a bit silly to time this, but
|
||||
).time():
|
||||
key = cache_key(binding, self.request)
|
||||
cached_policy = cache.get(key, None)
|
||||
if not cached_policy:
|
||||
return False
|
||||
self.logger.debug(
|
||||
"P_ENG: Taking result from cache",
|
||||
binding=binding,
|
||||
cache_key=key,
|
||||
request=self.request,
|
||||
)
|
||||
self.__cached_policies.append(cached_policy)
|
||||
return True
|
||||
|
||||
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
||||
"""Check static bindings if possible"""
|
||||
aggrs = {
|
||||
"total": Count(
|
||||
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
|
||||
),
|
||||
}
|
||||
if self.request.user.pk:
|
||||
all_groups = self.request.user.all_groups()
|
||||
aggrs["passing"] = Count(
|
||||
"pk",
|
||||
filter=Q(
|
||||
Q(
|
||||
Q(user=self.request.user) | Q(group__in=all_groups),
|
||||
negate=False,
|
||||
)
|
||||
| Q(
|
||||
Q(~Q(user=self.request.user), user__isnull=False)
|
||||
| Q(~Q(group__in=all_groups), group__isnull=False),
|
||||
negate=True,
|
||||
),
|
||||
enabled=True,
|
||||
),
|
||||
)
|
||||
matched_bindings = bindings.aggregate(**aggrs)
|
||||
passing = False
|
||||
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
|
||||
# If we didn't find any static bindings, do nothing
|
||||
return
|
||||
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||
# No matching static bindings but at least one is configured -> not passing
|
||||
passing = False
|
||||
self.__static_result = PolicyResult(passing)
|
||||
|
||||
def build(self) -> "PolicyEngine":
|
||||
"""Build wrapper which monitors performance"""
|
||||
with (
|
||||
@@ -123,7 +156,12 @@ class PolicyEngine:
|
||||
span: Span
|
||||
span.set_data("pbm", self.__pbm)
|
||||
span.set_data("request", self.request)
|
||||
for binding in self.iterate_bindings():
|
||||
bindings = self.bindings()
|
||||
policy_bindings = bindings
|
||||
if isinstance(bindings, QuerySet):
|
||||
self.compute_static_bindings(bindings)
|
||||
policy_bindings = [x for x in bindings if x.policy]
|
||||
for binding in policy_bindings:
|
||||
self.__expected_result_count += 1
|
||||
|
||||
self._check_policy_type(binding)
|
||||
@@ -153,10 +191,13 @@ class PolicyEngine:
|
||||
@property
|
||||
def result(self) -> PolicyResult:
|
||||
"""Get policy-checking result"""
|
||||
self.__processes.sort(key=lambda x: x.binding.order)
|
||||
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
|
||||
all_results = list(process_results + self.__cached_policies)
|
||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||
raise AssertionError("Got less results than polices")
|
||||
if self.__static_result:
|
||||
all_results.append(self.__static_result)
|
||||
# No results, no policies attached -> passing
|
||||
if len(all_results) == 0:
|
||||
return PolicyResult(self.empty_result)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from django.http import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.models import Event
|
||||
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
@@ -45,6 +46,10 @@ class PolicyEvaluator(BaseEvaluator):
|
||||
self.set_http_request(request.http_request)
|
||||
self._context["request"] = request
|
||||
self._context["context"] = request.context
|
||||
if request.obj and isinstance(request.obj, Event):
|
||||
self._context["ak_client_ip"] = ip_address(
|
||||
request.obj.client_ip or ClientIPMiddleware.default_ip
|
||||
)
|
||||
|
||||
def set_http_request(self, request: HttpRequest):
|
||||
"""Update context based on http request"""
|
||||
|
||||
@@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
@@ -95,10 +95,13 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
except PolicyException as exc:
|
||||
# Either use passed original exception or whatever we have
|
||||
src_exc = exc.src_exc if exc.src_exc else exc
|
||||
error_string = exception_to_string(src_exc)
|
||||
# Create policy exception event, only when we're not debugging
|
||||
if not self.request.debug:
|
||||
self.create_event(EventAction.POLICY_EXCEPTION, message=error_string)
|
||||
self.create_event(
|
||||
EventAction.POLICY_EXCEPTION,
|
||||
message="Policy failed to execute",
|
||||
exception=exception_to_dict(src_exc),
|
||||
)
|
||||
LOGGER.debug("P_ENG(proc): error, using failure result", exc=src_exc)
|
||||
policy_result = PolicyResult(self.binding.failure_result, str(src_exc))
|
||||
policy_result.source_binding = self.binding
|
||||
@@ -143,5 +146,5 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
try:
|
||||
self.connection.send(self.profiling_wrapper())
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Policy failed to run", exc=exception_to_string(exc))
|
||||
LOGGER.warning("Policy failed to run", exc=exc)
|
||||
self.connection.send(PolicyResult(False, str(exc)))
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""policy engine tests"""
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import connections
|
||||
from django.test import TestCase
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@@ -19,7 +22,7 @@ class TestPolicyEngine(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_policy_cache()
|
||||
self.user = create_test_admin_user()
|
||||
self.user = create_test_user()
|
||||
self.policy_false = DummyPolicy.objects.create(
|
||||
name=generate_id(), result=False, wait_min=0, wait_max=1
|
||||
)
|
||||
@@ -127,3 +130,58 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
|
||||
def test_engine_static_bindings(self):
|
||||
"""Test static bindings"""
|
||||
group_a = Group.objects.create(name=generate_id())
|
||||
group_b = Group.objects.create(name=generate_id())
|
||||
group_b.users.add(self.user)
|
||||
user = create_test_user()
|
||||
|
||||
for case in [
|
||||
{
|
||||
"message": "Group, not member",
|
||||
"binding_args": {"group": group_a},
|
||||
"passing": False,
|
||||
},
|
||||
{
|
||||
"message": "Group, member",
|
||||
"binding_args": {"group": group_b},
|
||||
"passing": True,
|
||||
},
|
||||
{
|
||||
"message": "User, other",
|
||||
"binding_args": {"user": user},
|
||||
"passing": False,
|
||||
},
|
||||
{
|
||||
"message": "User, same",
|
||||
"binding_args": {"user": self.user},
|
||||
"passing": True,
|
||||
},
|
||||
]:
|
||||
with self.subTest():
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
for x in range(1000):
|
||||
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertEqual(engine.result.passing, case["passing"])
|
||||
|
||||
def test_engine_group_complex(self):
|
||||
"""Test more complex group setups"""
|
||||
group_a = Group.objects.create(name=generate_id())
|
||||
group_b = Group.objects.create(name=generate_id(), parent=group_a)
|
||||
user = create_test_user()
|
||||
group_b.users.add(user)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, order=0, group=group_a)
|
||||
engine = PolicyEngine(pbm, user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertTrue(engine.result.passing)
|
||||
|
||||
@@ -29,13 +29,12 @@ class TestPolicyProcess(TestCase):
|
||||
def setUp(self):
|
||||
clear_policy_cache()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(username="policyuser")
|
||||
self.user = User.objects.create_user(username=generate_id())
|
||||
|
||||
def test_group_passing(self):
|
||||
"""Test binding to group"""
|
||||
group = Group.objects.create(name="test-group")
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(self.user)
|
||||
group.save()
|
||||
binding = PolicyBinding(group=group)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@@ -44,8 +43,7 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_group_negative(self):
|
||||
"""Test binding to group"""
|
||||
group = Group.objects.create(name="test-group")
|
||||
group.save()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
binding = PolicyBinding(group=group)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@@ -115,8 +113,10 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_exception(self):
|
||||
"""Test policy execution"""
|
||||
policy = Policy.objects.create(name="test-execution")
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
policy = Policy.objects.create(name=generate_id())
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
response = PolicyProcess(binding, request, None).execute()
|
||||
@@ -125,13 +125,15 @@ class TestPolicyProcess(TestCase):
|
||||
def test_execution_logging(self):
|
||||
"""Test policy execution creates event"""
|
||||
policy = DummyPolicy.objects.create(
|
||||
name="test-execution-logging",
|
||||
name=generate_id(),
|
||||
result=False,
|
||||
wait_min=0,
|
||||
wait_max=1,
|
||||
execution_logging=True,
|
||||
)
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
|
||||
http_request.user = self.user
|
||||
@@ -186,13 +188,15 @@ class TestPolicyProcess(TestCase):
|
||||
def test_execution_logging_anonymous(self):
|
||||
"""Test policy execution creates event with anonymous user"""
|
||||
policy = DummyPolicy.objects.create(
|
||||
name="test-execution-logging-anon",
|
||||
name=generate_id(),
|
||||
result=False,
|
||||
wait_min=0,
|
||||
wait_max=1,
|
||||
execution_logging=True,
|
||||
)
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
user = AnonymousUser()
|
||||
|
||||
@@ -219,9 +223,9 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_raises(self):
|
||||
"""Test policy that raises error"""
|
||||
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
|
||||
policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
|
||||
binding = PolicyBinding(
|
||||
policy=policy_raises, target=Application.objects.create(name="test")
|
||||
policy=policy_raises, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@@ -237,4 +241,4 @@ class TestPolicyProcess(TestCase):
|
||||
self.assertEqual(len(events), 1)
|
||||
event = events.first()
|
||||
self.assertEqual(event.user["username"], self.user.username)
|
||||
self.assertIn("division by zero", event.context["message"])
|
||||
self.assertIn("Policy failed to execute", event.context["message"])
|
||||
|
||||
@@ -15,12 +15,14 @@ class OAuth2Error(SentryIgnoredException):
|
||||
|
||||
error: str
|
||||
description: str
|
||||
cause: str | None = None
|
||||
|
||||
def create_dict(self):
|
||||
def create_dict(self, request: HttpRequest):
|
||||
"""Return error as dict for JSON Rendering"""
|
||||
return {
|
||||
"error": self.error,
|
||||
"error_description": self.description,
|
||||
"request_id": request.request_id,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -31,9 +33,15 @@ class OAuth2Error(SentryIgnoredException):
|
||||
return Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=message or self.description,
|
||||
cause=self.cause,
|
||||
error=self.error,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def with_cause(self, cause: str):
|
||||
self.cause = cause
|
||||
return self
|
||||
|
||||
|
||||
class RedirectUriError(OAuth2Error):
|
||||
"""The request fails due to a missing, invalid, or mismatching
|
||||
@@ -243,13 +251,14 @@ class TokenRevocationError(OAuth2Error):
|
||||
self.description = self.errors[error]
|
||||
|
||||
|
||||
class DeviceCodeError(OAuth2Error):
|
||||
class DeviceCodeError(TokenError):
|
||||
"""
|
||||
Device-code flow errors
|
||||
See https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||
Can also use codes form TokenError
|
||||
"""
|
||||
|
||||
errors = {
|
||||
errors = TokenError.errors | {
|
||||
"authorization_pending": (
|
||||
"The authorization request is still pending as the end user hasn't "
|
||||
"yet completed the user-interaction steps"
|
||||
@@ -261,10 +270,15 @@ class DeviceCodeError(OAuth2Error):
|
||||
"authorization request but SHOULD wait for user interaction before "
|
||||
"restarting to avoid unnecessary polling."
|
||||
),
|
||||
"slow_down": (
|
||||
'A variant of "authorization_pending", the authorization request is'
|
||||
"still pending and polling should continue, but the interval MUST"
|
||||
"be increased by 5 seconds for this and all subsequent requests."
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, error: str):
|
||||
super().__init__()
|
||||
super().__init__(error)
|
||||
self.error = error
|
||||
self.description = self.errors[error]
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.constants import TOKEN_TYPE
|
||||
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
|
||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
@@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||
|
||||
def test_invalid_client_id(self):
|
||||
"""Test invalid client ID"""
|
||||
@@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "request_not_supported")
|
||||
|
||||
def test_invalid_redirect_uri(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
def test_invalid_redirect_uri_missing(self):
|
||||
"""test missing redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
|
||||
|
||||
def test_invalid_redirect_uri(self):
|
||||
"""test invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_blocked_redirect_uri(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
@@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||
|
||||
def test_invalid_redirect_uri_empty(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
@@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_redirect_uri_invalid_regex(self):
|
||||
"""test missing/invalid redirect URI (invalid regex)"""
|
||||
@@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_empty_redirect_uri(self):
|
||||
"""test empty redirect URI (configure in provider)"""
|
||||
def test_redirect_uri_regex(self):
|
||||
"""test valid redirect URI (regex)"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"redirect_uri": "http://foo.bar.baz",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
@@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
GrantTypes.IMPLICIT,
|
||||
)
|
||||
# Implicit without openid scope
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
||||
)
|
||||
with self.assertRaises(AuthorizeError):
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||
|
||||
def test_full_code(self):
|
||||
"""Test full authorization"""
|
||||
@@ -613,3 +619,54 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_openid_missing_invalid(self):
|
||||
"""test request requiring an OpenID scope to be set"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "id_token",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"scope": "",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "scope_openid_missing")
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_offline_access_invalid(self):
|
||||
"""test request for offline_access with invalid response type"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-offline_access",
|
||||
]
|
||||
)
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "id_token",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
|
||||
"nonce": generate_id(),
|
||||
},
|
||||
)
|
||||
parsed = OAuthAuthorizationParams.from_request(request)
|
||||
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
|
||||
|
||||
@@ -68,7 +68,11 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_no_provider(self):
|
||||
@@ -87,7 +91,11 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
@@ -110,7 +118,11 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_incorrect_scopes(self):
|
||||
|
||||
@@ -68,7 +68,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_wrong_token(self):
|
||||
@@ -85,7 +89,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_no_provider(self):
|
||||
@@ -104,7 +112,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
@@ -127,7 +139,11 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_successful(self):
|
||||
|
||||
@@ -68,7 +68,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_wrong_token(self):
|
||||
@@ -86,7 +90,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_no_provider(self):
|
||||
@@ -106,7 +114,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
@@ -130,7 +142,11 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
|
||||
def test_successful(self):
|
||||
|
||||
@@ -80,6 +80,7 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
"revoked, does not match the redirection URI used in the authorization "
|
||||
"request, or was issued to another client"
|
||||
),
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
@@ -136,6 +137,7 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
"revoked, does not match the redirection URI used in the authorization "
|
||||
"request, or was issued to another client"
|
||||
),
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
@@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
|
||||
allowed_redirect_urls = self.provider.redirect_uris
|
||||
if not self.redirect_uri:
|
||||
LOGGER.warning("Missing redirect uri.")
|
||||
raise RedirectUriError("", allowed_redirect_urls)
|
||||
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||
|
||||
if len(allowed_redirect_urls) < 1:
|
||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||
@@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
|
||||
provider=self.provider,
|
||||
)
|
||||
if not match_found:
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||
"redirect_uri_no_match"
|
||||
)
|
||||
# Check against forbidden schemes
|
||||
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||
"redirect_uri_forbidden_scheme"
|
||||
)
|
||||
|
||||
def check_scope(self, github_compat=False):
|
||||
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
||||
@@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
|
||||
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||
):
|
||||
LOGGER.warning("Missing 'openid' scope.")
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
||||
).with_cause("scope_openid_missing")
|
||||
if SCOPE_OFFLINE_ACCESS in self.scope:
|
||||
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
||||
# Don't explicitly request consent with offline_access, as the spec allows for
|
||||
@@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
|
||||
return
|
||||
if not self.nonce:
|
||||
LOGGER.warning("Missing nonce for OpenID Request")
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||
).with_cause("nonce_missing")
|
||||
|
||||
def check_code_challenge(self):
|
||||
"""PKCE validation of the transformation method."""
|
||||
@@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
||||
self.request, github_compat=self.github_compat
|
||||
)
|
||||
except AuthorizeError as error:
|
||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
|
||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
|
||||
raise RequestValidationError(error.get_response(self.request)) from None
|
||||
except OAuth2Error as error:
|
||||
LOGGER.warning(error.description)
|
||||
LOGGER.warning(error.description, cause=error.cause)
|
||||
raise RequestValidationError(
|
||||
bad_request_message(self.request, error.description, title=error.error)
|
||||
) from None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, JsonResponse
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.urls import reverse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.utils.timezone import now
|
||||
@@ -14,7 +14,9 @@ from structlog.stdlib import get_logger
|
||||
from authentik.core.models import Application
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.utils import TokenResponse
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -28,38 +30,36 @@ class DeviceView(View):
|
||||
provider: OAuth2Provider
|
||||
scopes: list[str] = []
|
||||
|
||||
def parse_request(self) -> HttpResponse | None:
|
||||
def parse_request(self):
|
||||
"""Parse incoming request"""
|
||||
client_id = self.request.POST.get("client_id", None)
|
||||
if not client_id:
|
||||
return HttpResponseBadRequest()
|
||||
provider = OAuth2Provider.objects.filter(
|
||||
client_id=client_id,
|
||||
).first()
|
||||
raise DeviceCodeError("invalid_client")
|
||||
provider = OAuth2Provider.objects.filter(client_id=client_id).first()
|
||||
if not provider:
|
||||
return HttpResponseBadRequest()
|
||||
raise DeviceCodeError("invalid_client")
|
||||
try:
|
||||
_ = provider.application
|
||||
except Application.DoesNotExist:
|
||||
return HttpResponseBadRequest()
|
||||
raise DeviceCodeError("invalid_client") from None
|
||||
self.provider = provider
|
||||
self.client_id = client_id
|
||||
self.scopes = self.request.POST.get("scope", "").split(" ")
|
||||
return None
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
throttle = AnonRateThrottle()
|
||||
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
|
||||
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
||||
if not throttle.allow_request(request, self):
|
||||
return HttpResponse(status=429)
|
||||
return TokenResponse(DeviceCodeError("slow_down").create_dict(request), status=429)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def post(self, request: HttpRequest) -> HttpResponse:
|
||||
"""Generate device token"""
|
||||
resp = self.parse_request()
|
||||
if resp:
|
||||
return resp
|
||||
try:
|
||||
self.parse_request()
|
||||
except DeviceCodeError as exc:
|
||||
return TokenResponse(exc.create_dict(request), status=400)
|
||||
until = timedelta_from_string(self.provider.access_code_validity)
|
||||
token: DeviceToken = DeviceToken.objects.create(
|
||||
expires=now() + until, provider=self.provider, _scope=" ".join(self.scopes)
|
||||
@@ -67,7 +67,7 @@ class DeviceView(View):
|
||||
device_url = self.request.build_absolute_uri(
|
||||
reverse("authentik_providers_oauth2_root:device-login")
|
||||
)
|
||||
return JsonResponse(
|
||||
return TokenResponse(
|
||||
{
|
||||
"device_code": token.device_code,
|
||||
"verification_uri": device_url,
|
||||
|
||||
@@ -598,9 +598,9 @@ class TokenView(View):
|
||||
return TokenResponse(self.create_device_code_response())
|
||||
raise TokenError("unsupported_grant_type")
|
||||
except (TokenError, DeviceCodeError) as error:
|
||||
return TokenResponse(error.create_dict(), status=400)
|
||||
return TokenResponse(error.create_dict(request), status=400)
|
||||
except UserAuthError as error:
|
||||
return TokenResponse(error.create_dict(), status=403)
|
||||
return TokenResponse(error.create_dict(request), status=403)
|
||||
|
||||
def create_code_response(self) -> dict[str, Any]:
|
||||
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1"""
|
||||
|
||||
@@ -65,7 +65,7 @@ class TokenRevokeView(View):
|
||||
|
||||
return TokenResponse(data={}, status=200)
|
||||
except TokenRevocationError as exc:
|
||||
return TokenResponse(exc.create_dict(), status=401)
|
||||
return TokenResponse(exc.create_dict(request), status=401)
|
||||
except Http404:
|
||||
# Token not found should return a HTTP 200
|
||||
# https://datatracker.ietf.org/doc/html/rfc7009#section-2.2
|
||||
|
||||
@@ -102,6 +102,7 @@ class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
|
||||
# Buffer sizes for large headers with JWTs
|
||||
"nginx.ingress.kubernetes.io/proxy-buffers-number": "4",
|
||||
"nginx.ingress.kubernetes.io/proxy-buffer-size": "16k",
|
||||
"nginx.ingress.kubernetes.io/proxy-busy-buffers-size": "32k",
|
||||
# Enable TLS in traefik
|
||||
"traefik.ingress.kubernetes.io/router.tls": "true",
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ from authentik.core.models import Application
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.exceptions import ControlFlowException
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.types import PolicyResult
|
||||
@@ -142,9 +141,9 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
message="Failed to evaluate property-mapping",
|
||||
mapping=exc.mapping,
|
||||
).save()
|
||||
).with_exception(exc).save()
|
||||
return None
|
||||
return b64encode(packet.RequestPacket()).decode()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
||||
from pydanticscim.group import Group as BaseGroup
|
||||
from pydanticscim.responses import PatchOperation as BasePatchOperation
|
||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||
@@ -12,19 +12,95 @@ from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import (
|
||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||
)
|
||||
from pydanticscim.user import AddressKind
|
||||
from pydanticscim.user import User as BaseUser
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
|
||||
|
||||
class Address(BaseModel):
|
||||
formatted: str | None = Field(
|
||||
None,
|
||||
description="The full mailing address, formatted for display "
|
||||
"or use with a mailing label. This attribute MAY contain newlines.",
|
||||
)
|
||||
streetAddress: str | None = Field(
|
||||
None,
|
||||
description="The full street address component, which may "
|
||||
"include house number, street name, P.O. box, and multi-line "
|
||||
"extended street address information. This attribute MAY contain newlines.",
|
||||
)
|
||||
locality: str | None = Field(None, description="The city or locality component.")
|
||||
region: str | None = Field(None, description="The state or region component.")
|
||||
postalCode: str | None = Field(None, description="The zip code or postal code component.")
|
||||
country: str | None = Field(None, description="The country name component.")
|
||||
type: AddressKind | None = Field(
|
||||
None,
|
||||
description="A label indicating the attribute's function, e.g., 'work' or 'home'.",
|
||||
)
|
||||
primary: bool | None = None
|
||||
|
||||
|
||||
class Manager(BaseModel):
|
||||
value: str | None = Field(
|
||||
None,
|
||||
description="The id of the SCIM resource representingthe User's manager. REQUIRED.",
|
||||
)
|
||||
ref: AnyUrl | None = Field(
|
||||
None,
|
||||
alias="$ref",
|
||||
description="The URI of the SCIM resource representing the User's manager. REQUIRED.",
|
||||
)
|
||||
displayName: str | None = Field(
|
||||
None,
|
||||
description="The displayName of the User's manager. OPTIONAL and READ-ONLY.",
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseUser(BaseModel):
|
||||
employeeNumber: str | None = Field(
|
||||
None,
|
||||
description="Numeric or alphanumeric identifier assigned to a person, "
|
||||
"typically based on order of hire or association with anorganization.",
|
||||
)
|
||||
costCenter: str | None = Field(None, description="Identifies the name of a cost center.")
|
||||
organization: str | None = Field(None, description="Identifies the name of an organization.")
|
||||
division: str | None = Field(None, description="Identifies the name of a division.")
|
||||
department: str | None = Field(
|
||||
None,
|
||||
description="Numeric or alphanumeric identifier assigned to a person,"
|
||||
" typically based on order of hire or association with anorganization.",
|
||||
)
|
||||
manager: Manager | None = Field(
|
||||
None,
|
||||
description="The User's manager. A complex type that optionally allows "
|
||||
"service providers to represent organizational hierarchy by referencing"
|
||||
" the 'id' attribute of another User.",
|
||||
)
|
||||
|
||||
|
||||
class User(BaseUser):
|
||||
"""Modified User schema with added externalId field"""
|
||||
|
||||
model_config = ConfigDict(serialize_by_alias=True)
|
||||
|
||||
id: str | int | None = None
|
||||
schemas: list[str] = [SCIM_USER_SCHEMA]
|
||||
externalId: str | None = None
|
||||
meta: dict | None = None
|
||||
addresses: list[Address] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"A physical mailing address for this User. Canonical type "
|
||||
"values of 'work', 'home', and 'other'."
|
||||
),
|
||||
)
|
||||
enterprise_user: EnterpriseUser | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
serialization_alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class Group(BaseGroup):
|
||||
@@ -92,7 +168,7 @@ class PatchOperation(BasePatchOperation):
|
||||
"""PatchOperation with optional path"""
|
||||
|
||||
op: PatchOp
|
||||
path: str | None
|
||||
path: str | None = None
|
||||
|
||||
|
||||
class SCIMError(BaseSCIMError):
|
||||
|
||||
@@ -28,7 +28,6 @@ from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.lib.sentry import should_ignore_exception
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
# set the default Django settings module for the 'celery' program.
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
|
||||
@@ -83,8 +82,8 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
|
||||
CTX_TASK_ID.set(...)
|
||||
if not should_ignore_exception(exception):
|
||||
Event.new(
|
||||
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
|
||||
).save()
|
||||
EventAction.SYSTEM_EXCEPTION, message="Failed to execute task", task_id=task_id
|
||||
).with_exception(exception).save()
|
||||
|
||||
|
||||
def _get_startup_tasks_default_tenant() -> list[Callable]:
|
||||
|
||||
@@ -49,6 +49,8 @@ class ReadyView(View):
|
||||
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||
try:
|
||||
for db_conn in connections.all():
|
||||
# Force connection reload
|
||||
db_conn.connect()
|
||||
_ = db_conn.cursor()
|
||||
except OperationalError: # pragma: no cover
|
||||
return HttpResponse(status=503)
|
||||
|
||||
@@ -156,16 +156,17 @@ SPECTACULAR_SETTINGS = {
|
||||
},
|
||||
"ENUM_NAME_OVERRIDES": {
|
||||
"CountryCodeEnum": "django_countries.countries",
|
||||
"DeviceClassesEnum": "authentik.stages.authenticator_validate.models.DeviceClasses",
|
||||
"EventActions": "authentik.events.models.EventAction",
|
||||
"FlowDesignationEnum": "authentik.flows.models.FlowDesignation",
|
||||
"FlowLayoutEnum": "authentik.flows.models.FlowLayout",
|
||||
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
||||
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
|
||||
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
|
||||
"UserTypeEnum": "authentik.core.models.UserTypes",
|
||||
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
|
||||
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
||||
"UserTypeEnum": "authentik.core.models.UserTypes",
|
||||
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
|
||||
},
|
||||
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
|
||||
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
|
||||
|
||||
@@ -4,11 +4,11 @@ from pathlib import Path
|
||||
from secrets import token_urlsafe
|
||||
from tempfile import gettempdir
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test import TransactionTestCase
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
class TestRoot(TestCase):
|
||||
class TestRoot(TransactionTestCase):
|
||||
"""Test root application"""
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -8,7 +8,6 @@ from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.kerberos.models import KerberosSource
|
||||
from authentik.sources.kerberos.sync import KerberosSync
|
||||
@@ -64,5 +63,5 @@ def kerberos_sync_single(self, source_pk: str):
|
||||
syncer.sync()
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages)
|
||||
except StopSync as exc:
|
||||
LOGGER.warning(exception_to_string(exc))
|
||||
LOGGER.warning("Error syncing kerberos", exc=exc, source=source)
|
||||
self.set_error(exc)
|
||||
|
||||
@@ -12,7 +12,6 @@ from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.ldap.models import LDAPSource
|
||||
@@ -149,5 +148,5 @@ def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key:
|
||||
cache.delete(page_cache_key)
|
||||
except (LDAPException, StopSync) as exc:
|
||||
# No explicit event is created here as .set_status with an error will do that
|
||||
LOGGER.warning(exception_to_string(exc))
|
||||
LOGGER.warning("Failed to sync LDAP", exc=exc, source=source)
|
||||
self.set_error(exc)
|
||||
|
||||
@@ -10,6 +10,7 @@ AUTHENTIK_SOURCES_OAUTH_TYPES = [
|
||||
"authentik.sources.oauth.types.apple",
|
||||
"authentik.sources.oauth.types.azure_ad",
|
||||
"authentik.sources.oauth.types.discord",
|
||||
"authentik.sources.oauth.types.entra_id",
|
||||
"authentik.sources.oauth.types.facebook",
|
||||
"authentik.sources.oauth.types.github",
|
||||
"authentik.sources.oauth.types.gitlab",
|
||||
|
||||
@@ -232,7 +232,7 @@ class GoogleOAuthSource(CreatableType, OAuthSource):
|
||||
|
||||
|
||||
class AzureADOAuthSource(CreatableType, OAuthSource):
|
||||
"""Social Login using Azure AD."""
|
||||
"""(Deprecated) Social Login using Azure AD."""
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
@@ -240,6 +240,17 @@ class AzureADOAuthSource(CreatableType, OAuthSource):
|
||||
verbose_name_plural = _("Azure AD OAuth Sources")
|
||||
|
||||
|
||||
# TODO: When removing this, add a migration for OAuthSource that sets
|
||||
# provider_type to `entraid` if it is currently `azuread`
|
||||
class EntraIDOAuthSource(CreatableType, OAuthSource):
|
||||
"""Social Login using Entra ID."""
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
verbose_name = _("Entra ID OAuth Source")
|
||||
verbose_name_plural = _("Entra ID OAuth Sources")
|
||||
|
||||
|
||||
class OpenIDConnectOAuthSource(CreatableType, OAuthSource):
|
||||
"""Login using a Generic OpenID-Connect compliant provider."""
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""azure ad Type tests"""
|
||||
"""Entra ID Type tests"""
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback, AzureADType
|
||||
from authentik.sources.oauth.types.entra_id import EntraIDOAuthCallback, EntraIDType
|
||||
|
||||
# https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2
|
||||
AAD_USER = {
|
||||
EID_USER = {
|
||||
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users/$entity",
|
||||
"@odata.id": (
|
||||
"https://graph.microsoft.com/v2/7ce9b89e-646a-41d2-9fa6-8371c6a8423d/"
|
||||
@@ -41,11 +41,11 @@ class TestTypeAzureAD(TestCase):
|
||||
|
||||
def test_enroll_context(self):
|
||||
"""Test azure_ad Enrollment context"""
|
||||
ak_context = AzureADType().get_base_user_properties(source=self.source, info=AAD_USER)
|
||||
self.assertEqual(ak_context["username"], AAD_USER["userPrincipalName"])
|
||||
self.assertEqual(ak_context["email"], AAD_USER["mail"])
|
||||
self.assertEqual(ak_context["name"], AAD_USER["displayName"])
|
||||
ak_context = EntraIDType().get_base_user_properties(source=self.source, info=EID_USER)
|
||||
self.assertEqual(ak_context["username"], EID_USER["userPrincipalName"])
|
||||
self.assertEqual(ak_context["email"], EID_USER["mail"])
|
||||
self.assertEqual(ak_context["name"], EID_USER["displayName"])
|
||||
|
||||
def test_user_id(self):
|
||||
"""Test azure AD user ID"""
|
||||
self.assertEqual(AzureADOAuthCallback().get_user_id(AAD_USER), AAD_USER["id"])
|
||||
"""Test Entra ID user ID"""
|
||||
self.assertEqual(EntraIDOAuthCallback().get_user_id(EID_USER), EID_USER["id"])
|
||||
@@ -1,105 +1,17 @@
|
||||
"""AzureAD OAuth2 Views"""
|
||||
|
||||
from typing import Any
|
||||
from authentik.sources.oauth.types.entra_id import EntraIDType
|
||||
from authentik.sources.oauth.types.registry import registry
|
||||
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class AzureADOAuthRedirect(OAuthRedirect):
|
||||
"""Azure AD OAuth2 Redirect"""
|
||||
|
||||
def get_additional_parameters(self, source): # pragma: no cover
|
||||
return {
|
||||
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
|
||||
}
|
||||
|
||||
|
||||
class AzureADClient(UserprofileHeaderAuthClient):
|
||||
"""Fetch AzureAD group information"""
|
||||
|
||||
def get_profile_info(self, token):
|
||||
profile_data = super().get_profile_info(token)
|
||||
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
|
||||
return profile_data
|
||||
group_response = self.session.request(
|
||||
"get",
|
||||
"https://graph.microsoft.com/v1.0/me/memberOf",
|
||||
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
|
||||
)
|
||||
try:
|
||||
group_response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning(
|
||||
"Unable to fetch user profile",
|
||||
exc=exc,
|
||||
response=exc.response.text if exc.response else str(exc),
|
||||
)
|
||||
return None
|
||||
profile_data["raw_groups"] = group_response.json()
|
||||
return profile_data
|
||||
|
||||
|
||||
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
|
||||
"""AzureAD OAuth2 Callback"""
|
||||
|
||||
client_class = AzureADClient
|
||||
|
||||
def get_user_id(self, info: dict[str, str]) -> str:
|
||||
# Default try to get `id` for the Graph API endpoint
|
||||
# fallback to OpenID logic in case the profile URL was changed
|
||||
return info.get("id", super().get_user_id(info))
|
||||
# TODO: When removing this, add a migration for OAuthSource that sets
|
||||
# provider_type to `entraid` if it is currently `azuread`
|
||||
|
||||
|
||||
@registry.register()
|
||||
class AzureADType(SourceType):
|
||||
class AzureADType(EntraIDType):
|
||||
"""Azure AD Type definition"""
|
||||
|
||||
callback_view = AzureADOAuthCallback
|
||||
redirect_view = AzureADOAuthRedirect
|
||||
verbose_name = "Azure AD"
|
||||
name = "azuread"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
|
||||
profile_url = "https://graph.microsoft.com/v1.0/me"
|
||||
oidc_well_known_url = (
|
||||
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
|
||||
)
|
||||
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
||||
|
||||
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||
|
||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||
# Format group info
|
||||
groups = []
|
||||
group_id_dict = {}
|
||||
for group in info.get("raw_groups", {}).get("value", []):
|
||||
if group["@odata.type"] != "#microsoft.graph.group":
|
||||
continue
|
||||
groups.append(group["id"])
|
||||
group_id_dict[group["id"]] = group
|
||||
info["raw_groups"] = group_id_dict
|
||||
return {
|
||||
"username": info.get("userPrincipalName"),
|
||||
"email": mail,
|
||||
"name": info.get("displayName"),
|
||||
"groups": groups,
|
||||
}
|
||||
|
||||
def get_base_group_properties(self, source, group_id, **kwargs):
|
||||
raw_group = kwargs["info"]["raw_groups"][group_id]
|
||||
return {
|
||||
"name": raw_group["displayName"],
|
||||
}
|
||||
|
||||
102
authentik/sources/oauth/types/entra_id.py
Normal file
102
authentik/sources/oauth/types/entra_id.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""EntraID OAuth2 Views"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class EntraIDOAuthRedirect(OAuthRedirect):
|
||||
"""Entra ID OAuth2 Redirect"""
|
||||
|
||||
def get_additional_parameters(self, source): # pragma: no cover
|
||||
return {
|
||||
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
|
||||
}
|
||||
|
||||
|
||||
class EntraIDClient(UserprofileHeaderAuthClient):
|
||||
"""Fetch EntraID group information"""
|
||||
|
||||
def get_profile_info(self, token):
|
||||
profile_data = super().get_profile_info(token)
|
||||
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
|
||||
return profile_data
|
||||
group_response = self.session.request(
|
||||
"get",
|
||||
"https://graph.microsoft.com/v1.0/me/memberOf",
|
||||
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
|
||||
)
|
||||
try:
|
||||
group_response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning(
|
||||
"Unable to fetch user profile",
|
||||
exc=exc,
|
||||
response=exc.response.text if exc.response else str(exc),
|
||||
)
|
||||
return None
|
||||
profile_data["raw_groups"] = group_response.json()
|
||||
return profile_data
|
||||
|
||||
|
||||
class EntraIDOAuthCallback(OpenIDConnectOAuth2Callback):
|
||||
"""EntraID OAuth2 Callback"""
|
||||
|
||||
client_class = EntraIDClient
|
||||
|
||||
def get_user_id(self, info: dict[str, str]) -> str:
|
||||
# Default try to get `id` for the Graph API endpoint
|
||||
# fallback to OpenID logic in case the profile URL was changed
|
||||
return info.get("id", super().get_user_id(info))
|
||||
|
||||
|
||||
@registry.register()
|
||||
class EntraIDType(SourceType):
|
||||
"""Entra ID Type definition"""
|
||||
|
||||
callback_view = EntraIDOAuthCallback
|
||||
redirect_view = EntraIDOAuthRedirect
|
||||
verbose_name = "Entra ID"
|
||||
name = "entraid"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
|
||||
profile_url = "https://graph.microsoft.com/v1.0/me"
|
||||
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
||||
|
||||
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||
|
||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||
# Format group info
|
||||
groups = []
|
||||
group_id_dict = {}
|
||||
for group in info.get("raw_groups", {}).get("value", []):
|
||||
if group["@odata.type"] != "#microsoft.graph.group":
|
||||
continue
|
||||
groups.append(group["id"])
|
||||
group_id_dict[group["id"]] = group
|
||||
info["raw_groups"] = group_id_dict
|
||||
return {
|
||||
"username": info.get("userPrincipalName"),
|
||||
"email": mail,
|
||||
"name": info.get("displayName"),
|
||||
"groups": groups,
|
||||
}
|
||||
|
||||
def get_base_group_properties(self, source, group_id, **kwargs):
|
||||
raw_group = kwargs["info"]["raw_groups"][group_id]
|
||||
return {
|
||||
"name": raw_group["displayName"],
|
||||
}
|
||||
@@ -18,6 +18,7 @@ class SCIMSourceGroupSerializer(SourceSerializer):
|
||||
model = SCIMSourceGroup
|
||||
fields = [
|
||||
"id",
|
||||
"external_id",
|
||||
"group",
|
||||
"group_obj",
|
||||
"source",
|
||||
@@ -31,5 +32,5 @@ class SCIMSourceGroupViewSet(UsedByMixin, ModelViewSet):
|
||||
queryset = SCIMSourceGroup.objects.all().select_related("group")
|
||||
serializer_class = SCIMSourceGroupSerializer
|
||||
filterset_fields = ["source__slug", "group__name", "group__group_uuid"]
|
||||
search_fields = ["source__slug", "group__name", "attributes"]
|
||||
search_fields = ["source__slug", "group__name", "attributes", "external_id"]
|
||||
ordering = ["group__name"]
|
||||
|
||||
@@ -18,6 +18,7 @@ class SCIMSourceUserSerializer(SourceSerializer):
|
||||
model = SCIMSourceUser
|
||||
fields = [
|
||||
"id",
|
||||
"external_id",
|
||||
"user",
|
||||
"user_obj",
|
||||
"source",
|
||||
@@ -31,5 +32,5 @@ class SCIMSourceUserViewSet(UsedByMixin, ModelViewSet):
|
||||
queryset = SCIMSourceUser.objects.all().select_related("user")
|
||||
serializer_class = SCIMSourceUserSerializer
|
||||
filterset_fields = ["source__slug", "user__username", "user__id"]
|
||||
search_fields = ["source__slug", "user__username", "attributes"]
|
||||
search_fields = ["source__slug", "user__username", "attributes", "user__uuid", "external_id"]
|
||||
ordering = ["user__username"]
|
||||
|
||||
4
authentik/sources/scim/constants.py
Normal file
4
authentik/sources/scim/constants.py
Normal file
@@ -0,0 +1,4 @@
|
||||
SCIM_URN_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_URN_GROUP = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_URN_USER = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_URN_USER_ENTERPRISE = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
@@ -1,8 +0,0 @@
|
||||
"""SCIM Errors"""
|
||||
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
|
||||
|
||||
class PatchError(SentryIgnoredException):
|
||||
"""Error raised within an atomic block when an error happened
|
||||
so nothing is saved"""
|
||||
@@ -0,0 +1,98 @@
|
||||
# Generated by Django 5.1.11 on 2025-07-13 01:07
|
||||
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
from django.apps.registry import Apps
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
def migrate_ext_id(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
SCIMSourceUser = apps.get_model("authentik_sources_scim", "SCIMSourceUser")
|
||||
SCIMSourceGroup = apps.get_model("authentik_sources_scim", "SCIMSourceGroup")
|
||||
db_alias = schema_editor.connection.alias
|
||||
for user in SCIMSourceUser.objects.using(db_alias).all():
|
||||
user.external_id = user.id
|
||||
user.save(update_fields=["external_id"])
|
||||
for group in SCIMSourceGroup.objects.using(db_alias).all():
|
||||
group.external_id = group.id
|
||||
group.save(update_fields=["external_id"])
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_sources_scim", "0002_scimsourcepropertymapping"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourcegroup",
|
||||
unique_together=set(),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourceuser",
|
||||
unique_together=set(),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourcegroup",
|
||||
name="external_id",
|
||||
field=models.TextField(default=None, null=True),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourceuser",
|
||||
name="external_id",
|
||||
field=models.TextField(default=None, null=True),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourcegroup",
|
||||
unique_together={("external_id", "source")},
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourceuser",
|
||||
unique_together={("external_id", "source")},
|
||||
),
|
||||
migrations.RunPython(migrate_ext_id, migrations.RunPython.noop),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourcegroup",
|
||||
name="external_id",
|
||||
field=models.TextField(),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourceuser",
|
||||
name="external_id",
|
||||
field=models.TextField(),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="scimsourcegroup",
|
||||
index=models.Index(fields=["external_id"], name="authentik_s_externa_05e346_idx"),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="scimsourceuser",
|
||||
index=models.Index(fields=["external_id"], name="authentik_s_externa_4bd760_idx"),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourcegroup",
|
||||
name="id",
|
||||
field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourceuser",
|
||||
name="id",
|
||||
field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourcegroup",
|
||||
name="last_update",
|
||||
field=models.DateTimeField(auto_now=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourceuser",
|
||||
name="last_update",
|
||||
field=models.DateTimeField(auto_now=True),
|
||||
),
|
||||
]
|
||||
@@ -1,6 +1,7 @@
|
||||
"""SCIM Source"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
from django.templatetags.static import static
|
||||
@@ -103,10 +104,12 @@ class SCIMSourcePropertyMapping(PropertyMapping):
|
||||
class SCIMSourceUser(SerializerModel):
|
||||
"""Mapping of a user and source to a SCIM user ID"""
|
||||
|
||||
id = models.TextField(primary_key=True)
|
||||
id = models.TextField(primary_key=True, default=uuid4)
|
||||
external_id = models.TextField()
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
last_update = models.DateTimeField(auto_now=True)
|
||||
|
||||
@property
|
||||
def serializer(self) -> BaseSerializer:
|
||||
@@ -115,7 +118,10 @@ class SCIMSourceUser(SerializerModel):
|
||||
return SCIMSourceUserSerializer
|
||||
|
||||
class Meta:
|
||||
unique_together = (("id", "user", "source"),)
|
||||
unique_together = (("external_id", "source"),)
|
||||
indexes = [
|
||||
models.Index(fields=["external_id"]),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM User {self.user_id} to {self.source_id}"
|
||||
@@ -124,10 +130,12 @@ class SCIMSourceUser(SerializerModel):
|
||||
class SCIMSourceGroup(SerializerModel):
|
||||
"""Mapping of a group and source to a SCIM user ID"""
|
||||
|
||||
id = models.TextField(primary_key=True)
|
||||
id = models.TextField(primary_key=True, default=uuid4)
|
||||
external_id = models.TextField()
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
last_update = models.DateTimeField(auto_now=True)
|
||||
|
||||
@property
|
||||
def serializer(self) -> BaseSerializer:
|
||||
@@ -136,7 +144,10 @@ class SCIMSourceGroup(SerializerModel):
|
||||
return SCIMSourceGroupSerializer
|
||||
|
||||
class Meta:
|
||||
unique_together = (("id", "group", "source"),)
|
||||
unique_together = (("external_id", "source"),)
|
||||
indexes = [
|
||||
models.Index(fields=["external_id"]),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM Group {self.group_id} to {self.source_id}"
|
||||
|
||||
0
authentik/sources/scim/patch/__init__.py
Normal file
0
authentik/sources/scim/patch/__init__.py
Normal file
180
authentik/sources/scim/patch/lexer.py
Normal file
180
authentik/sources/scim/patch/lexer.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from authentik.sources.scim.constants import (
|
||||
SCIM_URN_GROUP,
|
||||
SCIM_URN_SCHEMA,
|
||||
SCIM_URN_USER,
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
)
|
||||
|
||||
|
||||
# Token types for SCIM path parsing
|
||||
class TokenType(Enum):
|
||||
ATTRIBUTE = "ATTRIBUTE"
|
||||
DOT = "DOT"
|
||||
LBRACKET = "LBRACKET"
|
||||
RBRACKET = "RBRACKET"
|
||||
LPAREN = "LPAREN"
|
||||
RPAREN = "RPAREN"
|
||||
STRING = "STRING"
|
||||
NUMBER = "NUMBER"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
NULL = "NULL"
|
||||
OPERATOR = "OPERATOR"
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
NOT = "NOT"
|
||||
EOF = "EOF"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
type: TokenType
|
||||
value: str
|
||||
position: int = 0
|
||||
|
||||
|
||||
class SCIMPathLexer:
|
||||
"""Lexer for SCIM paths and filter expressions"""
|
||||
|
||||
OPERATORS = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.schema_urns = [
|
||||
SCIM_URN_SCHEMA,
|
||||
SCIM_URN_GROUP,
|
||||
SCIM_URN_USER,
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
]
|
||||
self.text = text
|
||||
self.pos = 0
|
||||
self.current_char = self.text[self.pos] if self.pos < len(self.text) else None
|
||||
|
||||
def advance(self):
|
||||
"""Move to next character"""
|
||||
self.pos += 1
|
||||
self.current_char = self.text[self.pos] if self.pos < len(self.text) else None
|
||||
|
||||
def skip_whitespace(self):
|
||||
"""Skip whitespace characters"""
|
||||
while self.current_char and self.current_char.isspace():
|
||||
self.advance()
|
||||
|
||||
def read_string(self, quote_char):
|
||||
"""Read a quoted string"""
|
||||
value = ""
|
||||
self.advance() # Skip opening quote
|
||||
|
||||
while self.current_char and self.current_char != quote_char:
|
||||
if self.current_char == "\\":
|
||||
self.advance()
|
||||
if self.current_char:
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
else:
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
|
||||
if self.current_char == quote_char:
|
||||
self.advance() # Skip closing quote
|
||||
|
||||
return value
|
||||
|
||||
def read_number(self):
|
||||
"""Read a number (integer or float)"""
|
||||
value = ""
|
||||
while self.current_char and (self.current_char.isdigit() or self.current_char == "."):
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
return value
|
||||
|
||||
def read_identifier(self):
|
||||
"""Read an identifier (attribute name or operator) - supports URN format"""
|
||||
value = ""
|
||||
while self.current_char and (self.current_char.isalnum() or self.current_char in "_-:"):
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
# If the identifier value so far is a schema URN, take that as the identifier and
|
||||
# treat the next part as a sub_attribute
|
||||
if value in self.schema_urns:
|
||||
self.current_char = "."
|
||||
return value
|
||||
|
||||
# Handle dots within URN identifiers (like "2.0")
|
||||
# A dot is part of the identifier if it's followed by a digit
|
||||
if (
|
||||
self.current_char == "."
|
||||
and self.pos + 1 < len(self.text)
|
||||
and self.text[self.pos + 1].isdigit()
|
||||
):
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
# Continue reading digits after the dot
|
||||
while self.current_char and self.current_char.isdigit():
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
|
||||
return value
|
||||
|
||||
def get_next_token(self) -> Token: # noqa PLR0911
|
||||
"""Get the next token from the input"""
|
||||
while self.current_char:
|
||||
if self.current_char.isspace():
|
||||
self.skip_whitespace()
|
||||
continue
|
||||
|
||||
if self.current_char == ".":
|
||||
self.advance()
|
||||
return Token(TokenType.DOT, ".")
|
||||
|
||||
if self.current_char == "[":
|
||||
self.advance()
|
||||
return Token(TokenType.LBRACKET, "[")
|
||||
|
||||
if self.current_char == "]":
|
||||
self.advance()
|
||||
return Token(TokenType.RBRACKET, "]")
|
||||
|
||||
if self.current_char == "(":
|
||||
self.advance()
|
||||
return Token(TokenType.LPAREN, "(")
|
||||
|
||||
if self.current_char == ")":
|
||||
self.advance()
|
||||
return Token(TokenType.RPAREN, ")")
|
||||
|
||||
if self.current_char in "\"'":
|
||||
quote_char = self.current_char
|
||||
value = self.read_string(quote_char)
|
||||
return Token(TokenType.STRING, value)
|
||||
|
||||
if self.current_char.isdigit():
|
||||
value = self.read_number()
|
||||
return Token(TokenType.NUMBER, value)
|
||||
|
||||
if self.current_char.isalpha() or self.current_char == "_":
|
||||
value = self.read_identifier()
|
||||
|
||||
# Check for special keywords
|
||||
if value.lower() == "true":
|
||||
return Token(TokenType.BOOLEAN, True)
|
||||
elif value.lower() == "false":
|
||||
return Token(TokenType.BOOLEAN, False)
|
||||
elif value.lower() == "null":
|
||||
return Token(TokenType.NULL, None)
|
||||
elif value.lower() == "and":
|
||||
return Token(TokenType.AND, "and")
|
||||
elif value.lower() == "or":
|
||||
return Token(TokenType.OR, "or")
|
||||
elif value.lower() == "not":
|
||||
return Token(TokenType.NOT, "not")
|
||||
elif value.lower() in self.OPERATORS:
|
||||
return Token(TokenType.OPERATOR, value.lower())
|
||||
else:
|
||||
return Token(TokenType.ATTRIBUTE, value)
|
||||
|
||||
# Skip unknown characters
|
||||
self.advance()
|
||||
|
||||
return Token(TokenType.EOF, "")
|
||||
131
authentik/sources/scim/patch/parser.py
Normal file
131
authentik/sources/scim/patch/parser.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.scim.patch.lexer import SCIMPathLexer, TokenType
|
||||
|
||||
|
||||
class SCIMPathParser:
|
||||
"""Parser for SCIM paths including filter expressions"""
|
||||
|
||||
def __init__(self):
|
||||
self.lexer = None
|
||||
self.current_token = None
|
||||
|
||||
def parse_path(self, path: str | None) -> list[dict[str, Any]]:
|
||||
"""Parse a SCIM path into components"""
|
||||
self.lexer = SCIMPathLexer(path)
|
||||
self.current_token = self.lexer.get_next_token()
|
||||
|
||||
components = []
|
||||
|
||||
while self.current_token.type != TokenType.EOF:
|
||||
component = self._parse_path_component()
|
||||
if component:
|
||||
components.append(component)
|
||||
|
||||
return components
|
||||
|
||||
def _parse_path_component(self) -> dict[str, Any] | None:
|
||||
"""Parse a single path component"""
|
||||
if self.current_token.type != TokenType.ATTRIBUTE:
|
||||
return None
|
||||
|
||||
attribute = self.current_token.value
|
||||
self._consume(TokenType.ATTRIBUTE)
|
||||
|
||||
filter_expr = None
|
||||
sub_attribute = None
|
||||
|
||||
# Check for filter expression
|
||||
if self.current_token.type == TokenType.LBRACKET:
|
||||
self._consume(TokenType.LBRACKET)
|
||||
filter_expr = self._parse_filter_expression()
|
||||
self._consume(TokenType.RBRACKET)
|
||||
|
||||
# Check for sub-attribute
|
||||
if self.current_token.type == TokenType.DOT:
|
||||
self._consume(TokenType.DOT)
|
||||
if self.current_token.type == TokenType.ATTRIBUTE:
|
||||
sub_attribute = self.current_token.value
|
||||
self._consume(TokenType.ATTRIBUTE)
|
||||
|
||||
return {"attribute": attribute, "filter": filter_expr, "sub_attribute": sub_attribute}
|
||||
|
||||
def _parse_filter_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse a filter expression like 'primary eq true' or
|
||||
'type eq "work" and primary eq true'"""
|
||||
return self._parse_or_expression()
|
||||
|
||||
def _parse_or_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse OR expressions"""
|
||||
left = self._parse_and_expression()
|
||||
|
||||
while self.current_token.type == TokenType.OR:
|
||||
self._consume(TokenType.OR)
|
||||
right = self._parse_and_expression()
|
||||
left = {"type": "logical", "operator": "or", "left": left, "right": right}
|
||||
|
||||
return left
|
||||
|
||||
def _parse_and_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse AND expressions"""
|
||||
left = self._parse_primary_expression()
|
||||
|
||||
while self.current_token.type == TokenType.AND:
|
||||
self._consume(TokenType.AND)
|
||||
right = self._parse_primary_expression()
|
||||
left = {"type": "logical", "operator": "and", "left": left, "right": right}
|
||||
|
||||
return left
|
||||
|
||||
def _parse_primary_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse primary expressions (attribute operator value)"""
|
||||
if self.current_token.type == TokenType.LPAREN:
|
||||
self._consume(TokenType.LPAREN)
|
||||
expr = self._parse_or_expression()
|
||||
self._consume(TokenType.RPAREN)
|
||||
return expr
|
||||
|
||||
if self.current_token.type == TokenType.NOT:
|
||||
self._consume(TokenType.NOT)
|
||||
expr = self._parse_primary_expression()
|
||||
return {"type": "logical", "operator": "not", "operand": expr}
|
||||
|
||||
if self.current_token.type != TokenType.ATTRIBUTE:
|
||||
return None
|
||||
|
||||
attribute = self.current_token.value
|
||||
self._consume(TokenType.ATTRIBUTE)
|
||||
|
||||
if self.current_token.type != TokenType.OPERATOR:
|
||||
return None
|
||||
|
||||
operator = self.current_token.value
|
||||
self._consume(TokenType.OPERATOR)
|
||||
|
||||
# Parse value
|
||||
value = None
|
||||
if self.current_token.type == TokenType.STRING:
|
||||
value = self.current_token.value
|
||||
self._consume(TokenType.STRING)
|
||||
elif self.current_token.type == TokenType.NUMBER:
|
||||
value = (
|
||||
float(self.current_token.value)
|
||||
if "." in self.current_token.value
|
||||
else int(self.current_token.value)
|
||||
)
|
||||
self._consume(TokenType.NUMBER)
|
||||
elif self.current_token.type == TokenType.BOOLEAN:
|
||||
value = self.current_token.value
|
||||
self._consume(TokenType.BOOLEAN)
|
||||
elif self.current_token.type == TokenType.NULL:
|
||||
value = None
|
||||
self._consume(TokenType.NULL)
|
||||
|
||||
return {"type": "comparison", "attribute": attribute, "operator": operator, "value": value}
|
||||
|
||||
def _consume(self, expected_type: TokenType):
|
||||
"""Consume a token of the expected type"""
|
||||
if self.current_token.type == expected_type:
|
||||
self.current_token = self.lexer.get_next_token()
|
||||
else:
|
||||
raise ValueError(f"Expected {expected_type}, got {self.current_token.type}")
|
||||
246
authentik/sources/scim/patch/processor.py
Normal file
246
authentik/sources/scim/patch/processor.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from typing import Any
|
||||
|
||||
from authentik.providers.scim.clients.schema import PatchOp, PatchOperation
|
||||
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||
from authentik.sources.scim.patch.parser import SCIMPathParser
|
||||
|
||||
|
||||
class SCIMPatchProcessor:
|
||||
"""Processes SCIM patch operations on Python dictionaries"""
|
||||
|
||||
def __init__(self):
|
||||
self.parser = SCIMPathParser()
|
||||
|
||||
def apply_patches(self, data: dict[str, Any], patches: list[PatchOperation]) -> dict[str, Any]:
|
||||
"""Apply a list of patch operations to the data"""
|
||||
result = data.copy()
|
||||
|
||||
for _patch in patches:
|
||||
patch = PatchOperation.model_validate(_patch)
|
||||
if patch.path is None:
|
||||
# Handle operations with no path - value contains attribute paths as keys
|
||||
self._apply_bulk_operation(result, patch.op, patch.value)
|
||||
elif patch.op == PatchOp.add:
|
||||
self._apply_add(result, patch.path, patch.value)
|
||||
elif patch.op == PatchOp.remove:
|
||||
self._apply_remove(result, patch.path)
|
||||
elif patch.op == PatchOp.replace:
|
||||
self._apply_replace(result, patch.path, patch.value)
|
||||
|
||||
return result
|
||||
|
||||
def _apply_bulk_operation(
|
||||
self, data: dict[str, Any], operation: PatchOp, value: dict[str, Any]
|
||||
):
|
||||
"""Apply bulk operations when path is None"""
|
||||
if not isinstance(value, dict):
|
||||
return
|
||||
for path, val in value.items():
|
||||
if operation == PatchOp.add:
|
||||
self._apply_add(data, path, val)
|
||||
elif operation == PatchOp.remove:
|
||||
self._apply_remove(data, path)
|
||||
elif operation == PatchOp.replace:
|
||||
self._apply_replace(data, path, val)
|
||||
|
||||
def _apply_add(self, data: dict[str, Any], path: str, value: Any):
|
||||
"""Apply ADD operation"""
|
||||
components = self.parser.parse_path(path)
|
||||
|
||||
if len(components) == 1 and not components[0]["filter"]:
|
||||
# Simple path
|
||||
attr = components[0]["attribute"]
|
||||
if components[0]["sub_attribute"]:
|
||||
if attr not in data:
|
||||
data[attr] = {}
|
||||
# Somewhat hacky workaround for the manager attribute of the enterprise schema
|
||||
# ideally we'd do this based on the schema
|
||||
if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
|
||||
data[attr][components[0]["sub_attribute"]] = {"value": value}
|
||||
else:
|
||||
data[attr][components[0]["sub_attribute"]] = value
|
||||
elif attr in data:
|
||||
data[attr].append(value)
|
||||
else:
|
||||
data[attr] = value
|
||||
else:
|
||||
# Complex path with filters
|
||||
self._navigate_and_modify(data, components, value, "add")
|
||||
|
||||
def _apply_remove(self, data: dict[str, Any], path: str):
|
||||
"""Apply REMOVE operation"""
|
||||
components = self.parser.parse_path(path)
|
||||
|
||||
if len(components) == 1 and not components[0]["filter"]:
|
||||
# Simple path
|
||||
attr = components[0]["attribute"]
|
||||
if components[0]["sub_attribute"]:
|
||||
if attr in data and isinstance(data[attr], dict):
|
||||
data[attr].pop(components[0]["sub_attribute"], None)
|
||||
else:
|
||||
data.pop(attr, None)
|
||||
else:
|
||||
# Complex path with filters
|
||||
self._navigate_and_modify(data, components, None, "remove")
|
||||
|
||||
def _apply_replace(self, data: dict[str, Any], path: str, value: Any):
|
||||
"""Apply REPLACE operation"""
|
||||
components = self.parser.parse_path(path)
|
||||
|
||||
if len(components) == 1 and not components[0]["filter"]:
|
||||
# Simple path
|
||||
attr = components[0]["attribute"]
|
||||
if components[0]["sub_attribute"]:
|
||||
if attr not in data:
|
||||
data[attr] = {}
|
||||
# Somewhat hacky workaround for the manager attribute of the enterprise schema
|
||||
# ideally we'd do this based on the schema
|
||||
if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
|
||||
data[attr][components[0]["sub_attribute"]] = {"value": value}
|
||||
else:
|
||||
data[attr][components[0]["sub_attribute"]] = value
|
||||
else:
|
||||
data[attr] = value
|
||||
else:
|
||||
# Complex path with filters
|
||||
self._navigate_and_modify(data, components, value, "replace")
|
||||
|
||||
def _navigate_and_modify( # noqa PLR0912
|
||||
self, data: dict[str, Any], components: list[dict[str, Any]], value: Any, operation: str
|
||||
):
|
||||
"""Navigate through complex paths and apply modifications"""
|
||||
current = data
|
||||
|
||||
for i, component in enumerate(components):
|
||||
attr = component["attribute"]
|
||||
filter_expr = component["filter"]
|
||||
sub_attr = component["sub_attribute"]
|
||||
|
||||
if filter_expr:
|
||||
# Handle array with filter
|
||||
if attr not in current:
|
||||
if operation == "add":
|
||||
current[attr] = []
|
||||
else:
|
||||
return
|
||||
|
||||
if not isinstance(current[attr], list):
|
||||
return
|
||||
|
||||
# Find matching items
|
||||
matching_items = []
|
||||
for item in current[attr]:
|
||||
if self._matches_filter(item, filter_expr):
|
||||
matching_items.append(item)
|
||||
|
||||
if not matching_items and operation == "add":
|
||||
# Create new item if none match (only for simple comparison filters)
|
||||
if filter_expr.get("type", "comparison") == "comparison":
|
||||
new_item = {filter_expr["attribute"]: filter_expr["value"]}
|
||||
current[attr].append(new_item)
|
||||
matching_items = [new_item]
|
||||
|
||||
# Apply operation to matching items
|
||||
for item in matching_items:
|
||||
if sub_attr:
|
||||
if operation in {"add", "replace"}:
|
||||
item[sub_attr] = value
|
||||
elif operation == "remove":
|
||||
item.pop(sub_attr, None)
|
||||
elif operation in {"add", "replace"}:
|
||||
if isinstance(value, dict):
|
||||
item.update(value)
|
||||
else:
|
||||
# If value is not a dict, we can't merge it
|
||||
pass
|
||||
elif operation == "remove":
|
||||
# Remove the entire item
|
||||
if item in current[attr]:
|
||||
current[attr].remove(item)
|
||||
# Handle simple attribute
|
||||
elif i == len(components) - 1:
|
||||
# Last component
|
||||
if sub_attr:
|
||||
if attr not in current:
|
||||
current[attr] = {}
|
||||
if operation in {"add", "replace"}:
|
||||
current[attr][sub_attr] = value
|
||||
elif operation == "remove":
|
||||
current[attr].pop(sub_attr, None)
|
||||
elif operation in {"add", "replace"}:
|
||||
current[attr] = value
|
||||
elif operation == "remove":
|
||||
current.pop(attr, None)
|
||||
else:
|
||||
# Navigate deeper
|
||||
if attr not in current:
|
||||
current[attr] = {}
|
||||
current = current[attr]
|
||||
|
||||
def _matches_filter(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
|
||||
"""Check if an item matches the filter expression"""
|
||||
if not filter_expr:
|
||||
return True
|
||||
|
||||
filter_type = filter_expr.get("type", "comparison")
|
||||
|
||||
if filter_type == "comparison":
|
||||
return self._matches_comparison(item, filter_expr)
|
||||
elif filter_type == "logical":
|
||||
return self._matches_logical(item, filter_expr)
|
||||
|
||||
return False
|
||||
|
||||
def _matches_comparison( # noqa PLR0912
|
||||
self, item: dict[str, Any], filter_expr: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if an item matches a comparison filter"""
|
||||
attr = filter_expr["attribute"]
|
||||
operator = filter_expr["operator"]
|
||||
expected_value = filter_expr["value"]
|
||||
|
||||
if attr not in item:
|
||||
return False
|
||||
|
||||
actual_value = item[attr]
|
||||
|
||||
if operator == "eq":
|
||||
return actual_value == expected_value
|
||||
elif operator == "ne":
|
||||
return actual_value != expected_value
|
||||
elif operator == "co":
|
||||
return str(expected_value) in str(actual_value)
|
||||
elif operator == "sw":
|
||||
return str(actual_value).startswith(str(expected_value))
|
||||
elif operator == "ew":
|
||||
return str(actual_value).endswith(str(expected_value))
|
||||
elif operator == "gt":
|
||||
return actual_value > expected_value
|
||||
elif operator == "lt":
|
||||
return actual_value < expected_value
|
||||
elif operator == "ge":
|
||||
return actual_value >= expected_value
|
||||
elif operator == "le":
|
||||
return actual_value <= expected_value
|
||||
elif operator == "pr":
|
||||
return actual_value is not None
|
||||
|
||||
return False
|
||||
|
||||
def _matches_logical(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
|
||||
"""Check if an item matches a logical filter expression"""
|
||||
operator = filter_expr["operator"]
|
||||
|
||||
if operator == "and":
|
||||
left_result = self._matches_filter(item, filter_expr["left"])
|
||||
right_result = self._matches_filter(item, filter_expr["right"])
|
||||
return left_result and right_result
|
||||
elif operator == "or":
|
||||
left_result = self._matches_filter(item, filter_expr["left"])
|
||||
right_result = self._matches_filter(item, filter_expr["right"])
|
||||
return left_result or right_result
|
||||
elif operator == "not":
|
||||
operand_result = self._matches_filter(item, filter_expr["operand"])
|
||||
return not operand_result
|
||||
|
||||
return False
|
||||
@@ -1101,17 +1101,6 @@
|
||||
"returned": "default",
|
||||
"uniqueness": "none"
|
||||
},
|
||||
{
|
||||
"name": "password",
|
||||
"type": "string",
|
||||
"multiValued": false,
|
||||
"description": "The User's cleartext password. This attribute is intended to be used as a means to specify an initial\npassword when creating a new User or to reset an existing User's password.",
|
||||
"required": false,
|
||||
"caseExact": false,
|
||||
"mutability": "writeOnly",
|
||||
"returned": "never",
|
||||
"uniqueness": "none"
|
||||
},
|
||||
{
|
||||
"name": "emails",
|
||||
"type": "complex",
|
||||
|
||||
@@ -75,7 +75,9 @@ class TestSCIMGroups(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||
)
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
@@ -86,6 +88,7 @@ class TestSCIMGroups(APITestCase):
|
||||
"""Test group create"""
|
||||
user = create_test_user()
|
||||
ext_id = generate_id()
|
||||
name = generate_id()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -95,7 +98,7 @@ class TestSCIMGroups(APITestCase):
|
||||
),
|
||||
data=dumps(
|
||||
{
|
||||
"displayName": generate_id(),
|
||||
"displayName": name,
|
||||
"externalId": ext_id,
|
||||
"members": [{"value": str(user.uuid)}],
|
||||
}
|
||||
@@ -104,12 +107,22 @@ class TestSCIMGroups(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
||||
connection = SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).first()
|
||||
self.assertIsNotNone(connection)
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
).exists()
|
||||
)
|
||||
connection.refresh_from_db()
|
||||
self.assertEqual(
|
||||
connection.attributes,
|
||||
{
|
||||
"displayName": name,
|
||||
"externalId": ext_id,
|
||||
"members": [{"value": str(user.uuid)}],
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_create_members_empty(self):
|
||||
"""Test group create"""
|
||||
@@ -126,7 +139,9 @@ class TestSCIMGroups(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||
)
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
@@ -136,7 +151,9 @@ class TestSCIMGroups(APITestCase):
|
||||
def test_group_create_duplicate(self):
|
||||
"""Test group create (duplicate)"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
existing = SCIMSourceGroup.objects.create(
|
||||
source=self.source, group=group, external_id=uuid4()
|
||||
)
|
||||
ext_id = generate_id()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
@@ -165,7 +182,9 @@ class TestSCIMGroups(APITestCase):
|
||||
def test_group_update(self):
|
||||
"""Test group update"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
existing = SCIMSourceGroup.objects.create(
|
||||
source=self.source, group=group, external_id=uuid4()
|
||||
)
|
||||
ext_id = generate_id()
|
||||
response = self.client.put(
|
||||
reverse(
|
||||
@@ -205,12 +224,49 @@ class TestSCIMGroups(APITestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_patch_add(self):
|
||||
def test_group_patch_modify(self):
|
||||
"""Test group patch"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
connection = SCIMSourceGroup.objects.create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
external_id=uuid4(),
|
||||
attributes={"displayName": group.name, "members": []},
|
||||
)
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
|
||||
),
|
||||
data=dumps(
|
||||
{
|
||||
"Operations": [
|
||||
{
|
||||
"op": "Add",
|
||||
"value": {"externalId": "d85051cb-0557-4aa1-98ca-51eabcee4d40"},
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200, response.content)
|
||||
connection = SCIMSourceGroup.objects.filter(id="d85051cb-0557-4aa1-98ca-51eabcee4d40")
|
||||
self.assertIsNotNone(connection)
|
||||
|
||||
def test_group_patch_member_add(self):
|
||||
"""Test group patch"""
|
||||
user = create_test_user()
|
||||
|
||||
other_user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
group.users.add(other_user)
|
||||
connection = SCIMSourceGroup.objects.create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
external_id=uuid4(),
|
||||
attributes={"displayName": group.name, "members": [{"value": str(other_user.uuid)}]},
|
||||
)
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -222,7 +278,7 @@ class TestSCIMGroups(APITestCase):
|
||||
{
|
||||
"op": "Add",
|
||||
"path": "members",
|
||||
"value": {"value": str(user.uuid)},
|
||||
"value": [{"value": str(user.uuid)}],
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -230,16 +286,33 @@ class TestSCIMGroups(APITestCase):
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, second=200)
|
||||
self.assertEqual(response.status_code, 200, response.content)
|
||||
self.assertTrue(group.users.filter(pk=user.pk).exists())
|
||||
self.assertTrue(group.users.filter(pk=other_user.pk).exists())
|
||||
connection.refresh_from_db()
|
||||
self.assertEqual(
|
||||
connection.attributes,
|
||||
{
|
||||
"displayName": group.name,
|
||||
"members": sorted(
|
||||
[{"value": str(other_user.uuid)}, {"value": str(user.uuid)}],
|
||||
key=lambda u: u["value"],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_patch_remove(self):
|
||||
def test_group_patch_member_remove(self):
|
||||
"""Test group patch"""
|
||||
user = create_test_user()
|
||||
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(user)
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
connection = SCIMSourceGroup.objects.create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
external_id=uuid4(),
|
||||
attributes={"displayName": group.name, "members": []},
|
||||
)
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -251,7 +324,7 @@ class TestSCIMGroups(APITestCase):
|
||||
{
|
||||
"op": "remove",
|
||||
"path": "members",
|
||||
"value": {"value": str(user.uuid)},
|
||||
"value": [{"value": str(user.uuid)}],
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -259,13 +332,21 @@ class TestSCIMGroups(APITestCase):
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, second=200)
|
||||
self.assertEqual(response.status_code, 200, response.content)
|
||||
self.assertFalse(group.users.filter(pk=user.pk).exists())
|
||||
connection.refresh_from_db()
|
||||
self.assertEqual(
|
||||
connection.attributes,
|
||||
{
|
||||
"displayName": group.name,
|
||||
"members": [],
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_delete(self):
|
||||
"""Test group delete"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, external_id=uuid4())
|
||||
response = self.client.delete(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
|
||||
510
authentik/sources/scim/tests/test_lexer.py
Normal file
510
authentik/sources/scim/tests/test_lexer.py
Normal file
@@ -0,0 +1,510 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from authentik.sources.scim.constants import (
|
||||
SCIM_URN_GROUP,
|
||||
SCIM_URN_SCHEMA,
|
||||
SCIM_URN_USER,
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
)
|
||||
from authentik.sources.scim.patch.lexer import SCIMPathLexer, Token, TokenType
|
||||
|
||||
|
||||
class TestTokenType(TestCase):
|
||||
"""Test TokenType enum"""
|
||||
|
||||
def test_token_type_values(self):
|
||||
"""Test that all token types have correct values"""
|
||||
self.assertEqual(TokenType.ATTRIBUTE.value, "ATTRIBUTE")
|
||||
self.assertEqual(TokenType.DOT.value, "DOT")
|
||||
self.assertEqual(TokenType.LBRACKET.value, "LBRACKET")
|
||||
self.assertEqual(TokenType.RBRACKET.value, "RBRACKET")
|
||||
self.assertEqual(TokenType.LPAREN.value, "LPAREN")
|
||||
self.assertEqual(TokenType.RPAREN.value, "RPAREN")
|
||||
self.assertEqual(TokenType.STRING.value, "STRING")
|
||||
self.assertEqual(TokenType.NUMBER.value, "NUMBER")
|
||||
self.assertEqual(TokenType.BOOLEAN.value, "BOOLEAN")
|
||||
self.assertEqual(TokenType.NULL.value, "NULL")
|
||||
self.assertEqual(TokenType.OPERATOR.value, "OPERATOR")
|
||||
self.assertEqual(TokenType.AND.value, "AND")
|
||||
self.assertEqual(TokenType.OR.value, "OR")
|
||||
self.assertEqual(TokenType.NOT.value, "NOT")
|
||||
self.assertEqual(TokenType.EOF.value, "EOF")
|
||||
|
||||
|
||||
class TestToken(TestCase):
|
||||
"""Test Token dataclass"""
|
||||
|
||||
def test_token_creation(self):
|
||||
"""Test token creation with all parameters"""
|
||||
token = Token(TokenType.ATTRIBUTE, "userName", 5)
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
self.assertEqual(token.position, 5)
|
||||
|
||||
def test_token_creation_default_position(self):
|
||||
"""Test token creation with default position"""
|
||||
token = Token(TokenType.DOT, ".")
|
||||
self.assertEqual(token.type, TokenType.DOT)
|
||||
self.assertEqual(token.value, ".")
|
||||
self.assertEqual(token.position, 0)
|
||||
|
||||
|
||||
class TestSCIMPathLexer(TestCase):
|
||||
"""Test SCIMPathLexer class"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.simple_lexer = SCIMPathLexer("userName")
|
||||
|
||||
def test_init(self):
|
||||
"""Test lexer initialization"""
|
||||
lexer = SCIMPathLexer("test")
|
||||
self.assertEqual(lexer.text, "test")
|
||||
self.assertEqual(lexer.pos, 0)
|
||||
self.assertEqual(lexer.current_char, "t")
|
||||
self.assertIn(SCIM_URN_SCHEMA, lexer.schema_urns)
|
||||
self.assertIn(SCIM_URN_GROUP, lexer.schema_urns)
|
||||
self.assertIn(SCIM_URN_USER, lexer.schema_urns)
|
||||
self.assertIn(SCIM_URN_USER_ENTERPRISE, lexer.schema_urns)
|
||||
self.assertEqual(
|
||||
lexer.OPERATORS, ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||
)
|
||||
|
||||
def test_init_empty_string(self):
|
||||
"""Test lexer initialization with empty string"""
|
||||
lexer = SCIMPathLexer("")
|
||||
self.assertEqual(lexer.text, "")
|
||||
self.assertEqual(lexer.pos, 0)
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_advance(self):
|
||||
"""Test advance method"""
|
||||
lexer = SCIMPathLexer("abc")
|
||||
self.assertEqual(lexer.current_char, "a")
|
||||
|
||||
lexer.advance()
|
||||
self.assertEqual(lexer.pos, 1)
|
||||
self.assertEqual(lexer.current_char, "b")
|
||||
|
||||
lexer.advance()
|
||||
self.assertEqual(lexer.pos, 2)
|
||||
self.assertEqual(lexer.current_char, "c")
|
||||
|
||||
lexer.advance()
|
||||
self.assertEqual(lexer.pos, 3)
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_skip_whitespace(self):
|
||||
"""Test skip_whitespace method"""
|
||||
lexer = SCIMPathLexer(" \t\n abc")
|
||||
lexer.skip_whitespace()
|
||||
self.assertEqual(lexer.current_char, "a")
|
||||
|
||||
def test_skip_whitespace_only_whitespace(self):
|
||||
"""Test skip_whitespace with only whitespace"""
|
||||
lexer = SCIMPathLexer(" \t\n ")
|
||||
lexer.skip_whitespace()
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_skip_whitespace_no_whitespace(self):
|
||||
"""Test skip_whitespace with no leading whitespace"""
|
||||
lexer = SCIMPathLexer("abc")
|
||||
original_pos = lexer.pos
|
||||
lexer.skip_whitespace()
|
||||
self.assertEqual(lexer.pos, original_pos)
|
||||
self.assertEqual(lexer.current_char, "a")
|
||||
|
||||
def test_read_string_double_quotes(self):
|
||||
"""Test reading double-quoted string"""
|
||||
lexer = SCIMPathLexer('"hello world"')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, "hello world")
|
||||
self.assertIsNone(lexer.current_char) # Should be at end
|
||||
|
||||
def test_read_string_single_quotes(self):
|
||||
"""Test reading single-quoted string"""
|
||||
lexer = SCIMPathLexer("'hello world'")
|
||||
result = lexer.read_string("'")
|
||||
self.assertEqual(result, "hello world")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_string_with_escapes(self):
|
||||
"""Test reading string with escape characters"""
|
||||
lexer = SCIMPathLexer('"hello \\"world\\""')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, 'hello "world"')
|
||||
|
||||
def test_read_string_with_backslash_at_end(self):
|
||||
"""Test reading string with backslash at end"""
|
||||
lexer = SCIMPathLexer('"hello\\"')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, 'hello"')
|
||||
|
||||
def test_read_string_unclosed(self):
|
||||
"""Test reading unclosed string"""
|
||||
lexer = SCIMPathLexer('"hello world')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, "hello world")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_string_empty(self):
|
||||
"""Test reading empty string"""
|
||||
lexer = SCIMPathLexer('""')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, "")
|
||||
|
||||
def test_read_number_integer(self):
|
||||
"""Test reading integer number"""
|
||||
lexer = SCIMPathLexer("123")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, "123")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_number_float(self):
|
||||
"""Test reading float number"""
|
||||
lexer = SCIMPathLexer("123.456")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, "123.456")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_number_with_multiple_dots(self):
|
||||
"""Test reading number with multiple dots (invalid but handled)"""
|
||||
lexer = SCIMPathLexer("123.456.789")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, "123.456.789")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_number_starting_with_dot(self):
|
||||
"""Test reading number starting with dot"""
|
||||
lexer = SCIMPathLexer(".123")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, ".123")
|
||||
|
||||
def test_read_identifier_simple(self):
|
||||
"""Test reading simple identifier"""
|
||||
lexer = SCIMPathLexer("userName")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "userName")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_identifier_with_underscore(self):
|
||||
"""Test reading identifier with underscore"""
|
||||
lexer = SCIMPathLexer("user_name")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "user_name")
|
||||
|
||||
def test_read_identifier_with_hyphen(self):
|
||||
"""Test reading identifier with hyphen"""
|
||||
lexer = SCIMPathLexer("user-name")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "user-name")
|
||||
|
||||
def test_read_identifier_with_colon(self):
|
||||
"""Test reading identifier with colon (URN format)"""
|
||||
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
|
||||
def test_read_identifier_schema_urn(self):
|
||||
"""Test reading schema URN identifier"""
|
||||
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, SCIM_URN_USER)
|
||||
self.assertEqual(lexer.current_char, ".") # Should stop at dot and set current_char to dot
|
||||
|
||||
def test_read_identifier_with_version_number(self):
|
||||
"""Test reading identifier with version number (dots followed by digits)"""
|
||||
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
|
||||
def test_read_identifier_partial_urn_match(self):
|
||||
"""Test reading identifier that partially matches URN"""
|
||||
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:CustomUser")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:CustomUser")
|
||||
|
||||
# Test get_next_token method
|
||||
def test_get_next_token_dot(self):
|
||||
"""Test tokenizing dot"""
|
||||
lexer = SCIMPathLexer(".")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.DOT)
|
||||
self.assertEqual(token.value, ".")
|
||||
|
||||
def test_get_next_token_lbracket(self):
|
||||
"""Test tokenizing left bracket"""
|
||||
lexer = SCIMPathLexer("[")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.LBRACKET)
|
||||
self.assertEqual(token.value, "[")
|
||||
|
||||
def test_get_next_token_rbracket(self):
|
||||
"""Test tokenizing right bracket"""
|
||||
lexer = SCIMPathLexer("]")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.RBRACKET)
|
||||
self.assertEqual(token.value, "]")
|
||||
|
||||
def test_get_next_token_lparen(self):
|
||||
"""Test tokenizing left parenthesis"""
|
||||
lexer = SCIMPathLexer("(")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.LPAREN)
|
||||
self.assertEqual(token.value, "(")
|
||||
|
||||
def test_get_next_token_rparen(self):
|
||||
"""Test tokenizing right parenthesis"""
|
||||
lexer = SCIMPathLexer(")")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.RPAREN)
|
||||
self.assertEqual(token.value, ")")
|
||||
|
||||
def test_get_next_token_string_double_quotes(self):
|
||||
"""Test tokenizing double-quoted string"""
|
||||
lexer = SCIMPathLexer('"test string"')
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.STRING)
|
||||
self.assertEqual(token.value, "test string")
|
||||
|
||||
def test_get_next_token_string_single_quotes(self):
|
||||
"""Test tokenizing single-quoted string"""
|
||||
lexer = SCIMPathLexer("'test string'")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.STRING)
|
||||
self.assertEqual(token.value, "test string")
|
||||
|
||||
def test_get_next_token_number_integer(self):
|
||||
"""Test tokenizing integer"""
|
||||
lexer = SCIMPathLexer("123")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NUMBER)
|
||||
self.assertEqual(token.value, "123")
|
||||
|
||||
def test_get_next_token_number_float(self):
|
||||
"""Test tokenizing float"""
|
||||
lexer = SCIMPathLexer("123.45")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NUMBER)
|
||||
self.assertEqual(token.value, "123.45")
|
||||
|
||||
def test_get_next_token_boolean_true(self):
|
||||
"""Test tokenizing boolean true"""
|
||||
lexer = SCIMPathLexer("true")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||
self.assertTrue(token.value)
|
||||
|
||||
def test_get_next_token_boolean_false(self):
|
||||
"""Test tokenizing boolean false"""
|
||||
lexer = SCIMPathLexer("false")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||
self.assertFalse(token.value)
|
||||
|
||||
def test_get_next_token_boolean_case_insensitive(self):
|
||||
"""Test tokenizing boolean with different cases"""
|
||||
for value in ["TRUE", "True", "FALSE", "False"]:
|
||||
with self.subTest(value=value):
|
||||
lexer = SCIMPathLexer(value)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||
|
||||
def test_get_next_token_null(self):
|
||||
"""Test tokenizing null"""
|
||||
lexer = SCIMPathLexer("null")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NULL)
|
||||
self.assertIsNone(token.value)
|
||||
|
||||
def test_get_next_token_null_case_insensitive(self):
|
||||
"""Test tokenizing null with different cases"""
|
||||
for value in ["NULL", "Null"]:
|
||||
with self.subTest(value=value):
|
||||
lexer = SCIMPathLexer(value)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NULL)
|
||||
|
||||
def test_get_next_token_and(self):
|
||||
"""Test tokenizing AND operator"""
|
||||
lexer = SCIMPathLexer("and")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.AND)
|
||||
self.assertEqual(token.value, "and")
|
||||
|
||||
def test_get_next_token_or(self):
|
||||
"""Test tokenizing OR operator"""
|
||||
lexer = SCIMPathLexer("or")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.OR)
|
||||
self.assertEqual(token.value, "or")
|
||||
|
||||
def test_get_next_token_not(self):
|
||||
"""Test tokenizing NOT operator"""
|
||||
lexer = SCIMPathLexer("not")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NOT)
|
||||
self.assertEqual(token.value, "not")
|
||||
|
||||
def test_get_next_token_operators(self):
|
||||
"""Test tokenizing all comparison operators"""
|
||||
operators = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||
for op in operators:
|
||||
with self.subTest(operator=op):
|
||||
lexer = SCIMPathLexer(op)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.OPERATOR)
|
||||
self.assertEqual(token.value, op)
|
||||
|
||||
def test_get_next_token_operators_case_insensitive(self):
|
||||
"""Test tokenizing operators with different cases"""
|
||||
for op in ["EQ", "Eq", "NE", "Ne"]:
|
||||
with self.subTest(operator=op):
|
||||
lexer = SCIMPathLexer(op)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.OPERATOR)
|
||||
self.assertEqual(token.value, op.lower())
|
||||
|
||||
def test_get_next_token_attribute(self):
|
||||
"""Test tokenizing attribute name"""
|
||||
lexer = SCIMPathLexer("userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
|
||||
def test_get_next_token_attribute_with_underscore(self):
|
||||
"""Test tokenizing attribute name with underscore"""
|
||||
lexer = SCIMPathLexer("_userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "_userName")
|
||||
|
||||
def test_get_next_token_eof(self):
|
||||
"""Test tokenizing end of file"""
|
||||
lexer = SCIMPathLexer("")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.EOF)
|
||||
self.assertEqual(token.value, "")
|
||||
|
||||
def test_get_next_token_with_whitespace(self):
|
||||
"""Test tokenizing with leading whitespace"""
|
||||
lexer = SCIMPathLexer(" userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
|
||||
def test_get_next_token_skip_unknown_characters(self):
|
||||
"""Test that unknown characters are skipped"""
|
||||
lexer = SCIMPathLexer("@#$userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
|
||||
def test_get_next_token_multiple_tokens(self):
|
||||
"""Test tokenizing multiple tokens in sequence"""
|
||||
lexer = SCIMPathLexer("userName.givenName")
|
||||
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token1.value, "userName")
|
||||
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.DOT)
|
||||
self.assertEqual(token2.value, ".")
|
||||
|
||||
token3 = lexer.get_next_token()
|
||||
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token3.value, "givenName")
|
||||
|
||||
token4 = lexer.get_next_token()
|
||||
self.assertEqual(token4.type, TokenType.EOF)
|
||||
|
||||
def test_get_next_token_complex_filter(self):
|
||||
"""Test tokenizing complex filter expression"""
|
||||
lexer = SCIMPathLexer('emails[type eq "work" and primary eq true]')
|
||||
|
||||
tokens = []
|
||||
while True:
|
||||
token = lexer.get_next_token()
|
||||
tokens.append(token)
|
||||
if token.type == TokenType.EOF:
|
||||
break
|
||||
|
||||
expected_types = [
|
||||
TokenType.ATTRIBUTE, # emails
|
||||
TokenType.LBRACKET, # [
|
||||
TokenType.ATTRIBUTE, # type
|
||||
TokenType.OPERATOR, # eq
|
||||
TokenType.STRING, # "work"
|
||||
TokenType.AND, # and
|
||||
TokenType.ATTRIBUTE, # primary
|
||||
TokenType.OPERATOR, # eq
|
||||
TokenType.BOOLEAN, # true
|
||||
TokenType.RBRACKET, # ]
|
||||
TokenType.EOF,
|
||||
]
|
||||
|
||||
self.assertEqual(len(tokens), len(expected_types))
|
||||
for token, expected_type in zip(tokens, expected_types, strict=False):
|
||||
self.assertEqual(token.type, expected_type)
|
||||
|
||||
def test_get_next_token_urn_attribute(self):
|
||||
"""Test tokenizing URN-based attribute"""
|
||||
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
|
||||
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token1.value, SCIM_URN_USER)
|
||||
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.DOT)
|
||||
|
||||
token3 = lexer.get_next_token()
|
||||
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token3.value, "userName")
|
||||
|
||||
def test_get_next_token_enterprise_urn(self):
|
||||
"""Test tokenizing enterprise URN"""
|
||||
lexer = SCIMPathLexer(f"{SCIM_URN_USER_ENTERPRISE}.manager")
|
||||
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token1.value, SCIM_URN_USER_ENTERPRISE)
|
||||
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.DOT)
|
||||
|
||||
def test_lexer_state_after_eof(self):
|
||||
"""Test lexer state after reaching EOF"""
|
||||
lexer = SCIMPathLexer("a")
|
||||
|
||||
# Get first token
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
|
||||
# Get EOF token
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.EOF)
|
||||
|
||||
# Should continue returning EOF
|
||||
token3 = lexer.get_next_token()
|
||||
self.assertEqual(token3.type, TokenType.EOF)
|
||||
|
||||
def test_read_identifier_edge_cases(self):
|
||||
"""Test read_identifier with edge cases"""
|
||||
# Test identifier ending with colon
|
||||
lexer = SCIMPathLexer("test:")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "test:")
|
||||
|
||||
# Test identifier with numbers
|
||||
lexer = SCIMPathLexer("test123")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "test123")
|
||||
|
||||
def test_complex_urn_parsing(self):
|
||||
"""Test parsing complex URN with version numbers"""
|
||||
urn = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
lexer = SCIMPathLexer(urn)
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, urn)
|
||||
1254
authentik/sources/scim/tests/test_patch.py
Normal file
1254
authentik/sources/scim/tests/test_patch.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -10,6 +10,7 @@ from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
||||
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||
from authentik.sources.scim.models import SCIMSource, SCIMSourcePropertyMapping, SCIMSourceUser
|
||||
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
|
||||
|
||||
@@ -81,7 +82,9 @@ class TestSCIMUsers(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(SCIMSourceUser.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
SCIMSourceUser.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||
)
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
@@ -174,14 +177,16 @@ class TestSCIMUsers(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertEqual(
|
||||
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
|
||||
SCIMSourceUser.objects.get(source=self.source, external_id=ext_id).user.attributes[
|
||||
"phone"
|
||||
],
|
||||
"0123456789",
|
||||
)
|
||||
|
||||
def test_user_update(self):
|
||||
"""Test user update"""
|
||||
user = create_test_user()
|
||||
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
|
||||
existing = SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
|
||||
ext_id = generate_id()
|
||||
response = self.client.put(
|
||||
reverse(
|
||||
@@ -209,10 +214,51 @@ class TestSCIMUsers(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_user_update_patch(self):
|
||||
"""Test user update (patch)"""
|
||||
user = create_test_user()
|
||||
existing = SCIMSourceUser.objects.create(
|
||||
source=self.source,
|
||||
user=user,
|
||||
external_id=uuid4(),
|
||||
attributes={
|
||||
"userName": generate_id(),
|
||||
},
|
||||
)
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-users",
|
||||
kwargs={
|
||||
"source_slug": self.source.slug,
|
||||
"user_id": str(user.uuid),
|
||||
},
|
||||
),
|
||||
data=dumps(
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{
|
||||
"op": "Add",
|
||||
"path": f"{SCIM_URN_USER_ENTERPRISE}:manager",
|
||||
"value": "86b2ed3e-30cd-4881-bb58-c4e910821339",
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
existing.refresh_from_db()
|
||||
self.assertEqual(
|
||||
existing.attributes[SCIM_URN_USER_ENTERPRISE],
|
||||
{"manager": {"value": "86b2ed3e-30cd-4881-bb58-c4e910821339"}},
|
||||
)
|
||||
|
||||
def test_user_delete(self):
|
||||
"""Test user delete"""
|
||||
user = create_test_user()
|
||||
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
|
||||
SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
|
||||
response = self.client.delete(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-users",
|
||||
|
||||
488
authentik/sources/scim/tests/test_users_patch.py
Normal file
488
authentik/sources/scim/tests/test_users_patch.py
Normal file
@@ -0,0 +1,488 @@
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||
from authentik.sources.scim.models import SCIMSource, SCIMSourceUser
|
||||
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||
|
||||
|
||||
class TestSCIMUsersPatch(APITestCase):
|
||||
"""Test SCIM User Patch"""
|
||||
|
||||
def test_add(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Add", "path": "name.givenName", "value": "aqwer"},
|
||||
{"op": "Add", "path": "name.familyName", "value": "qwerqqqq"},
|
||||
{"op": "Add", "path": "name.formatted", "value": "aqwer qwerqqqq"},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"name": {
|
||||
"givenName": "aqwer",
|
||||
"familyName": "qwerqqqq",
|
||||
"formatted": "aqwer qwerqqqq",
|
||||
},
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_add_no_path(self):
|
||||
"""Test add patch with no path set"""
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Add", "value": {"externalId": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "aqwer",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_replace(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Replace", "path": "name", "value": {"givenName": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"name": {
|
||||
"givenName": "aqwer",
|
||||
},
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_replace_no_path(self):
|
||||
"""Test value replace with no path"""
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Replace", "value": {"externalId": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "aqwer",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_remove(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Remove", "path": "name", "value": {"givenName": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"name": {
|
||||
"givenName": "aqwer",
|
||||
},
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_large(self):
|
||||
"""Large amount of patch operations"""
|
||||
req = {
|
||||
"Operations": [
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "emails[primary eq true].value",
|
||||
"value": "dandre_kling@wintheiser.info",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "phoneNumbers[primary eq true].value",
|
||||
"value": "72-634-1548",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "phoneNumbers[primary eq true].display",
|
||||
"value": "72-634-1548",
|
||||
},
|
||||
{"op": "replace", "path": "ims[primary eq true].value", "value": "GXSGJKWGHVVS"},
|
||||
{"op": "replace", "path": "ims[primary eq true].display", "value": "IMCHDKUQIPYB"},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "photos[primary eq true].display",
|
||||
"value": "TWAWLHHSUNIV",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].formatted",
|
||||
"value": "TMINZQAJQDCL",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].streetAddress",
|
||||
"value": "081 Wisoky Key",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].locality",
|
||||
"value": "DPFASBZRPMDP",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].region",
|
||||
"value": "WHSTJSPIPTCF",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].postalCode",
|
||||
"value": "ko28 1qa",
|
||||
},
|
||||
{"op": "replace", "path": "addresses[primary eq true].country", "value": "Taiwan"},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "entitlements[primary eq true].value",
|
||||
"value": "NGBJMUYZVVBX",
|
||||
},
|
||||
{"op": "replace", "path": "roles[primary eq true].value", "value": "XEELVFMMWCVM"},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "x509Certificates[primary eq true].value",
|
||||
"value": "UYISMEDOXUZY",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"value": {
|
||||
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
|
||||
"name.formatted": "Dell",
|
||||
"name.familyName": "Gay",
|
||||
"name.givenName": "Kyler",
|
||||
"name.middleName": "Hannah",
|
||||
"name.honorificPrefix": "Cassie",
|
||||
"name.honorificSuffix": "Yolanda",
|
||||
"displayName": "DPRLIJSFQMTL",
|
||||
"nickName": "BKSPMIRMFBTI",
|
||||
"title": "NBZCOAXVYJUY",
|
||||
"userType": "ZGJMYZRUORZE",
|
||||
"preferredLanguage": "as-IN",
|
||||
"locale": "JLOJHLPWZODG",
|
||||
"timezone": "America/Argentina/Rio_Gallegos",
|
||||
"active": True,
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:employeeNumber": "PDFWRRZBQOHB",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:costCenter": "HACMZWSEDOTQ",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:organization": "LXVHJUOLNCLS",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:division": "JASVTPKPBPMG",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:department": "GMSBFLMNPABY",
|
||||
},
|
||||
},
|
||||
],
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"active": True,
|
||||
"addresses": [
|
||||
{
|
||||
"primary": "true",
|
||||
"formatted": "BLJMCNXHYLZK",
|
||||
"streetAddress": "7801 Jacobs Fork",
|
||||
"locality": "HZJBJWFAKXDD",
|
||||
"region": "GJXCXPMIIKWK",
|
||||
"postalCode": "pv82 8ua",
|
||||
"country": "India",
|
||||
}
|
||||
],
|
||||
"displayName": "KEFXCHKHAFOT",
|
||||
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
|
||||
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
|
||||
"externalId": "448d2786-7bf6-4e03-a4ef-64cbaf162fa7",
|
||||
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
|
||||
"locale": "PJNYJHWJILTI",
|
||||
"name": {
|
||||
"formatted": "Ladarius",
|
||||
"familyName": "Manley",
|
||||
"givenName": "Mazie",
|
||||
"middleName": "Vernon",
|
||||
"honorificPrefix": "Melyssa",
|
||||
"honorificSuffix": "Demarcus",
|
||||
},
|
||||
"nickName": "HTPKOXMWZKHL",
|
||||
"phoneNumbers": [
|
||||
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
|
||||
],
|
||||
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
|
||||
"preferredLanguage": "wae",
|
||||
"profileUrl": "HPSEOIPXMGOH",
|
||||
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"timezone": "America/Indiana/Petersburg",
|
||||
"title": "EJWFXLHNHMCD",
|
||||
SCIM_URN_USER_ENTERPRISE: {
|
||||
"employeeNumber": "XHDMEJUURJNR",
|
||||
"costCenter": "RXUYBXOTRCZH",
|
||||
"organization": "CEXWXMBRYAHN",
|
||||
"division": "XMPFMDCLRKCW",
|
||||
"department": "BKMNJVMCJUYS",
|
||||
"manager": "PNGSGXLYVWMV",
|
||||
},
|
||||
"userName": "imelda.auer@kshlerin.co.uk",
|
||||
"userType": "PZFXORVSUAPU",
|
||||
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"active": True,
|
||||
"addresses": [
|
||||
{
|
||||
"primary": "true",
|
||||
"formatted": "BLJMCNXHYLZK",
|
||||
"streetAddress": "7801 Jacobs Fork",
|
||||
"locality": "HZJBJWFAKXDD",
|
||||
"region": "GJXCXPMIIKWK",
|
||||
"postalCode": "pv82 8ua",
|
||||
"country": "India",
|
||||
}
|
||||
],
|
||||
"displayName": "DPRLIJSFQMTL",
|
||||
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
|
||||
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
|
||||
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
|
||||
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
|
||||
"locale": "JLOJHLPWZODG",
|
||||
"name": {
|
||||
"formatted": "Dell",
|
||||
"familyName": "Gay",
|
||||
"givenName": "Kyler",
|
||||
"middleName": "Hannah",
|
||||
"honorificPrefix": "Cassie",
|
||||
"honorificSuffix": "Yolanda",
|
||||
},
|
||||
"nickName": "BKSPMIRMFBTI",
|
||||
"phoneNumbers": [
|
||||
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
|
||||
],
|
||||
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
|
||||
"preferredLanguage": "as-IN",
|
||||
"profileUrl": "HPSEOIPXMGOH",
|
||||
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"timezone": "America/Argentina/Rio_Gallegos",
|
||||
"title": "NBZCOAXVYJUY",
|
||||
SCIM_URN_USER_ENTERPRISE: {
|
||||
"employeeNumber": "PDFWRRZBQOHB",
|
||||
"costCenter": "HACMZWSEDOTQ",
|
||||
"organization": "LXVHJUOLNCLS",
|
||||
"division": "JASVTPKPBPMG",
|
||||
"department": "GMSBFLMNPABY",
|
||||
"manager": "PNGSGXLYVWMV",
|
||||
},
|
||||
"userName": "imelda.auer@kshlerin.co.uk",
|
||||
"userType": "ZGJMYZRUORZE",
|
||||
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
|
||||
},
|
||||
)
|
||||
|
||||
def test_schema_urn_manager(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{
|
||||
"op": "Add",
|
||||
"value": {
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager": "foo"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": {
|
||||
"manager": {"value": "foo"}
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""SCIM Utils"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.paginator import Page, Paginator
|
||||
@@ -21,6 +22,7 @@ from authentik.core.sources.mapper import SourceMapper
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
from authentik.sources.scim.views.v2.auth import SCIMTokenAuth
|
||||
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
|
||||
|
||||
SCIM_CONTENT_TYPE = "application/scim+json"
|
||||
|
||||
@@ -54,6 +56,13 @@ class SCIMView(APIView):
|
||||
def get_authenticators(self):
|
||||
return [SCIMTokenAuth(self)]
|
||||
|
||||
def remove_excluded_attributes(self, data: dict):
|
||||
"""Remove attributes specified in excludedAttributes"""
|
||||
excluded: str = self.request.query_params.get("excludedAttributes", "")
|
||||
for key in excluded.split(","):
|
||||
data.pop(key.strip(), None)
|
||||
return data
|
||||
|
||||
def filter_parse(self, request: Request):
|
||||
"""Parse the path of a Patch Operation"""
|
||||
path = request.query_params.get("filter")
|
||||
@@ -103,6 +112,12 @@ class SCIMObjectView(SCIMView):
|
||||
# a source attribute before
|
||||
self.mapper = SourceMapper(self.source)
|
||||
self.manager = self.mapper.get_manager(self.model, ["data"])
|
||||
for key, value in kwargs.items():
|
||||
if key.endswith("_id"):
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise SCIMNotFoundError("Invalid ID") from None
|
||||
|
||||
def build_object_properties(self, data: dict[str, Any]) -> dict[str, Any | dict[str, Any]]:
|
||||
return self.mapper.build_object_properties(
|
||||
|
||||
@@ -17,6 +17,7 @@ from authentik.core.models import Group, User
|
||||
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
|
||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
|
||||
from authentik.sources.scim.models import SCIMSourceGroup
|
||||
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
||||
from authentik.sources.scim.views.v2.exceptions import (
|
||||
SCIMConflictError,
|
||||
@@ -35,11 +36,12 @@ class GroupsView(SCIMObjectView):
|
||||
payload = SCIMGroupModel(
|
||||
schemas=[SCIM_GROUP_SCHEMA],
|
||||
id=str(scim_group.group.pk),
|
||||
externalId=scim_group.id,
|
||||
externalId=scim_group.external_id,
|
||||
displayName=scim_group.group.name,
|
||||
members=[],
|
||||
meta={
|
||||
"resourceType": "Group",
|
||||
"lastModified": scim_group.last_update,
|
||||
"location": self.request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -54,7 +56,11 @@ class GroupsView(SCIMObjectView):
|
||||
for member in scim_group.group.users.order_by("pk"):
|
||||
member: User
|
||||
payload.members.append(GroupMember(value=str(member.uuid)))
|
||||
return payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload.update(scim_group.attributes)
|
||||
return self.remove_excluded_attributes(
|
||||
SCIMGroupModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
|
||||
)
|
||||
|
||||
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
|
||||
"""List Group handler"""
|
||||
@@ -81,7 +87,7 @@ class GroupsView(SCIMObjectView):
|
||||
)
|
||||
|
||||
@atomic
|
||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
|
||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict, apply_members=True):
|
||||
"""Partial update a group"""
|
||||
properties = self.build_object_properties(data)
|
||||
|
||||
@@ -94,7 +100,7 @@ class GroupsView(SCIMObjectView):
|
||||
|
||||
group.update_attributes(properties)
|
||||
|
||||
if "members" in data:
|
||||
if "members" in data and apply_members:
|
||||
query = Q()
|
||||
for _member in data.get("members", []):
|
||||
try:
|
||||
@@ -105,14 +111,18 @@ class GroupsView(SCIMObjectView):
|
||||
query |= Q(uuid=member.value)
|
||||
if query:
|
||||
group.users.set(User.objects.filter(query))
|
||||
data["members"] = self._convert_members(group)
|
||||
if not connection:
|
||||
connection, _ = SCIMSourceGroup.objects.get_or_create(
|
||||
connection, _ = SCIMSourceGroup.objects.update_or_create(
|
||||
external_id=data.get("externalId") or str(uuid4()),
|
||||
source=self.source,
|
||||
group=group,
|
||||
attributes=data,
|
||||
id=data.get("externalId") or str(uuid4()),
|
||||
defaults={
|
||||
"attributes": data,
|
||||
},
|
||||
)
|
||||
else:
|
||||
connection.external_id = data.get("externalId", connection.external_id)
|
||||
connection.attributes = data
|
||||
connection.save()
|
||||
return connection
|
||||
@@ -139,6 +149,12 @@ class GroupsView(SCIMObjectView):
|
||||
connection = self.update_group(connection, request.data)
|
||||
return Response(self.group_to_scim(connection), status=200)
|
||||
|
||||
def _convert_members(self, group: Group):
|
||||
users = []
|
||||
for user in group.users.all().order_by("uuid"):
|
||||
users.append({"value": str(user.uuid)})
|
||||
return sorted(users, key=lambda u: u["value"])
|
||||
|
||||
@atomic
|
||||
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
|
||||
"""Patch group handler"""
|
||||
@@ -171,6 +187,13 @@ class GroupsView(SCIMObjectView):
|
||||
query |= Q(uuid=member["value"])
|
||||
if query:
|
||||
connection.group.users.remove(*User.objects.filter(query))
|
||||
patcher = SCIMPatchProcessor()
|
||||
patched_data = patcher.apply_patches(
|
||||
connection.attributes, request.data.get("Operations", [])
|
||||
)
|
||||
patched_data["members"] = self._convert_members(connection.group)
|
||||
if patched_data != connection.attributes:
|
||||
self.update_group(connection, patched_data, apply_members=False)
|
||||
return Response(self.group_to_scim(connection), status=200)
|
||||
|
||||
@atomic
|
||||
|
||||
@@ -33,9 +33,7 @@ class ServiceProviderConfigView(SCIMView):
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
|
||||
"authenticationSchemes": auth_schemas,
|
||||
# We only support patch for groups currently, so don't broadly advertise it.
|
||||
# Implementations that require Group patch will use it regardless of this flag.
|
||||
"patch": {"supported": False},
|
||||
"patch": {"supported": True},
|
||||
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
|
||||
"filter": {
|
||||
"supported": True,
|
||||
|
||||
@@ -15,6 +15,7 @@ from authentik.core.models import User
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
|
||||
from authentik.sources.scim.models import SCIMSourceUser
|
||||
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
||||
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
|
||||
|
||||
@@ -29,7 +30,7 @@ class UsersView(SCIMObjectView):
|
||||
payload = SCIMUserModel(
|
||||
schemas=[SCIM_USER_SCHEMA],
|
||||
id=str(scim_user.user.uuid),
|
||||
externalId=scim_user.id,
|
||||
externalId=scim_user.external_id,
|
||||
userName=scim_user.user.username,
|
||||
name=Name(
|
||||
formatted=scim_user.user.name,
|
||||
@@ -44,8 +45,7 @@ class UsersView(SCIMObjectView):
|
||||
meta={
|
||||
"resourceType": "User",
|
||||
"created": scim_user.user.date_joined,
|
||||
# TODO: use events to find last edit?
|
||||
"lastModified": scim_user.user.date_joined,
|
||||
"lastModified": scim_user.last_update,
|
||||
"location": self.request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-users",
|
||||
@@ -59,7 +59,9 @@ class UsersView(SCIMObjectView):
|
||||
)
|
||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload.update(scim_user.attributes)
|
||||
return final_payload
|
||||
return self.remove_excluded_attributes(
|
||||
SCIMUserModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
|
||||
)
|
||||
|
||||
def get(self, request: Request, user_id: str | None = None, **kwargs) -> Response:
|
||||
"""List User handler"""
|
||||
@@ -101,13 +103,16 @@ class UsersView(SCIMObjectView):
|
||||
user.update_attributes(properties)
|
||||
|
||||
if not connection:
|
||||
connection, _ = SCIMSourceUser.objects.get_or_create(
|
||||
connection, _ = SCIMSourceUser.objects.update_or_create(
|
||||
external_id=data.get("externalId") or str(uuid4()),
|
||||
source=self.source,
|
||||
user=user,
|
||||
attributes=data,
|
||||
id=data.get("externalId") or str(uuid4()),
|
||||
defaults={
|
||||
"attributes": data,
|
||||
},
|
||||
)
|
||||
else:
|
||||
connection.external_id = data.get("externalId", connection.external_id)
|
||||
connection.attributes = data
|
||||
connection.save()
|
||||
return connection
|
||||
@@ -127,6 +132,18 @@ class UsersView(SCIMObjectView):
|
||||
connection = self.update_user(None, request.data)
|
||||
return Response(self.user_to_scim(connection), status=201)
|
||||
|
||||
def patch(self, request: Request, user_id: str, **kwargs):
|
||||
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
||||
if not connection:
|
||||
raise SCIMNotFoundError("User not found.")
|
||||
patcher = SCIMPatchProcessor()
|
||||
patched_data = patcher.apply_patches(
|
||||
connection.attributes, request.data.get("Operations", [])
|
||||
)
|
||||
if patched_data != connection.attributes:
|
||||
self.update_user(connection, patched_data)
|
||||
return Response(self.user_to_scim(connection), status=200)
|
||||
|
||||
def put(self, request: Request, user_id: str, **kwargs) -> Response:
|
||||
"""Update user handler"""
|
||||
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
||||
|
||||
@@ -13,7 +13,6 @@ from authentik.flows.exceptions import StageInvalidException
|
||||
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.stages.authenticator.models import SideChannelDevice
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
@@ -160,9 +159,8 @@ class EmailDevice(SerializerModel, SideChannelDevice):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=_("Exception occurred while rendering E-mail template"),
|
||||
error=exception_to_string(exc),
|
||||
template=stage.template,
|
||||
).from_http(self.request)
|
||||
).with_exception(exc).from_http(self.request)
|
||||
raise StageInvalidException from exc
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -17,7 +17,6 @@ from authentik.flows.challenge import (
|
||||
from authentik.flows.exceptions import StageInvalidException
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.lib.utils.email import mask_email
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.stages.authenticator_email.models import (
|
||||
AuthenticatorEmailStage,
|
||||
@@ -100,9 +99,8 @@ class AuthenticatorEmailStageView(ChallengeStageView):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=_("Exception occurred while rendering E-mail template"),
|
||||
error=exception_to_string(exc),
|
||||
template=stage.template,
|
||||
).from_http(self.request)
|
||||
).with_exception(exc).from_http(self.request)
|
||||
raise StageInvalidException from exc
|
||||
|
||||
def _has_email(self) -> str | None:
|
||||
|
||||
@@ -19,7 +19,6 @@ from authentik.events.models import Event, EventAction, NotificationWebhookMappi
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.stages.authenticator.models import SideChannelDevice
|
||||
|
||||
@@ -142,10 +141,9 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="Error sending SMS",
|
||||
exc=exception_to_string(exc),
|
||||
status_code=response.status_code,
|
||||
body=response.text,
|
||||
).set_user(device.user).save()
|
||||
).with_exception(exc).set_user(device.user).save()
|
||||
if response.status_code >= HttpResponseBadRequest.status_code:
|
||||
raise ValidationError(response.text) from None
|
||||
raise
|
||||
|
||||
@@ -9,7 +9,7 @@ from django.http.response import Http404
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext as __
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.fields import CharField, DateTimeField
|
||||
from rest_framework.fields import CharField, ChoiceField, DateTimeField
|
||||
from rest_framework.serializers import ValidationError
|
||||
from structlog.stdlib import get_logger
|
||||
from webauthn import options_to_json
|
||||
@@ -18,7 +18,7 @@ from webauthn.authentication.verify_authentication_response import verify_authen
|
||||
from webauthn.helpers import parse_authentication_credential_json
|
||||
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
|
||||
from webauthn.helpers.exceptions import InvalidAuthenticationResponse, InvalidJSONStructure
|
||||
from webauthn.helpers.structs import UserVerificationRequirement
|
||||
from webauthn.helpers.structs import PublicKeyCredentialType, UserVerificationRequirement
|
||||
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import Application, User
|
||||
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
|
||||
class DeviceChallenge(PassiveSerializer):
|
||||
"""Single device challenge"""
|
||||
|
||||
device_class = CharField()
|
||||
device_class = ChoiceField(choices=DeviceClasses.choices)
|
||||
device_uid = CharField()
|
||||
challenge = JSONDictField()
|
||||
last_used = DateTimeField(allow_null=True)
|
||||
@@ -157,6 +157,12 @@ def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -
|
||||
request = stage_view.request
|
||||
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
|
||||
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
|
||||
if "MinuteMaid" in request.META.get("HTTP_USER_AGENT", ""):
|
||||
# Workaround for Android sign-in, when signing into Google Workspace on android while
|
||||
# adding the account to the system (not in Chrome), for some reason `type` is not set
|
||||
# so in that case we fall back to `public-key`
|
||||
# since that's the only option we support anyways
|
||||
data.setdefault("type", PublicKeyCredentialType.PUBLIC_KEY)
|
||||
try:
|
||||
credential = parse_authentication_credential_json(data)
|
||||
except InvalidJSONStructure as exc:
|
||||
|
||||
@@ -173,6 +173,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
|
||||
{
|
||||
"auth_method": "auth_mfa",
|
||||
"auth_method_args": {
|
||||
"known_device": False,
|
||||
"mfa_devices": [
|
||||
{
|
||||
"app": "authentik_stages_authenticator_duo",
|
||||
@@ -180,7 +181,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
|
||||
"name": "",
|
||||
"pk": duo_device.pk,
|
||||
}
|
||||
]
|
||||
],
|
||||
},
|
||||
"http_request": {
|
||||
"args": {},
|
||||
|
||||
@@ -153,13 +153,13 @@ class AuthenticatorValidateStageTests(FlowTestCase):
|
||||
plan.append_stage(stage)
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
"device_class": "static",
|
||||
"device_class": DeviceClasses.STATIC,
|
||||
"device_uid": "1",
|
||||
"challenge": {},
|
||||
"last_used": now(),
|
||||
},
|
||||
{
|
||||
"device_class": "totp",
|
||||
"device_class": DeviceClasses.TOTP,
|
||||
"device_uid": "2",
|
||||
"challenge": {},
|
||||
"last_used": now(),
|
||||
@@ -172,7 +172,7 @@ class AuthenticatorValidateStageTests(FlowTestCase):
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
data={
|
||||
"selected_challenge": {
|
||||
"device_class": "baz",
|
||||
"device_class": DeviceClasses.WEBAUTHN,
|
||||
"device_uid": "quox",
|
||||
"challenge": {},
|
||||
"last_used": None,
|
||||
|
||||
@@ -162,7 +162,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
@@ -282,7 +282,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
@@ -359,7 +359,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
@@ -441,7 +441,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
"device_class": device.__class__.__name__.lower().replace("device", ""),
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -21,7 +21,6 @@ from authentik.flows.models import FlowDesignation, FlowToken
|
||||
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.flows.views.executor import QS_KEY_TOKEN, QS_QUERY
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.stages.email.flow import pickle_flow_token_for_email
|
||||
from authentik.stages.email.models import EmailStage
|
||||
@@ -129,9 +128,8 @@ class EmailStageView(ChallengeStageView):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=_("Exception occurred while rendering E-mail template"),
|
||||
error=exception_to_string(exc),
|
||||
template=current_stage.template,
|
||||
).from_http(self.request)
|
||||
).with_exception(exc).from_http(self.request)
|
||||
raise StageInvalidException from exc
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
@@ -145,7 +143,7 @@ class EmailStageView(ChallengeStageView):
|
||||
messages.success(request, _("Successfully verified Email."))
|
||||
if self.executor.current_stage.activate_user_on_success:
|
||||
user.is_active = True
|
||||
user.save()
|
||||
user.save(update_fields=["is_active"])
|
||||
return self.executor.stage_ok()
|
||||
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
|
||||
self.logger.debug("No pending user")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Prompt Stage Logic"""
|
||||
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Callable
|
||||
from email.policy import Policy
|
||||
from types import MethodType
|
||||
from typing import Any
|
||||
@@ -190,10 +190,11 @@ class ListPolicyEngine(PolicyEngine):
|
||||
self.__list = policies
|
||||
self.use_cache = False
|
||||
|
||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
||||
for policy in self.__list:
|
||||
def bindings(self):
|
||||
for idx, policy in enumerate(self.__list):
|
||||
yield PolicyBinding(
|
||||
policy=policy,
|
||||
order=idx,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ class TestPromptStage(FlowTestCase):
|
||||
"""Test challenge_response validation"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
expr = "False"
|
||||
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
||||
expr_policy = ExpressionPolicy.objects.create(name=generate_id(), expression=expr)
|
||||
self.stage.validation_policies.set([expr_policy])
|
||||
self.stage.save()
|
||||
challenge_response = PromptChallengeResponse(
|
||||
@@ -222,6 +222,18 @@ class TestPromptStage(FlowTestCase):
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), False)
|
||||
|
||||
def test_invalid_challenge_multiple(self):
|
||||
"""Test challenge_response validation (multiple policies)"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
expr_policy1 = ExpressionPolicy.objects.create(name=generate_id(), expression="False")
|
||||
expr_policy2 = ExpressionPolicy.objects.create(name=generate_id(), expression="False")
|
||||
self.stage.validation_policies.set([expr_policy1, expr_policy2])
|
||||
self.stage.save()
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), False)
|
||||
|
||||
def test_valid_challenge_request(self):
|
||||
"""Test a request with valid challenge_response data"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
@@ -234,7 +246,7 @@ class TestPromptStage(FlowTestCase):
|
||||
"return request.context['prompt_data']['password_prompt'] "
|
||||
"== request.context['prompt_data']['password2_prompt']"
|
||||
)
|
||||
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
||||
expr_policy = ExpressionPolicy.objects.create(name=generate_id(), expression=expr)
|
||||
self.stage.validation_policies.set([expr_policy])
|
||||
self.stage.save()
|
||||
challenge_response = PromptChallengeResponse(
|
||||
|
||||
@@ -18,6 +18,7 @@ class UserLoginStageSerializer(StageSerializer):
|
||||
"remember_me_offset",
|
||||
"network_binding",
|
||||
"geoip_binding",
|
||||
"remember_device",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from django.contrib.auth.views import redirect_to_login
|
||||
from django.http.request import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Session
|
||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
@@ -89,7 +90,7 @@ class BoundSessionMiddleware(SessionMiddleware):
|
||||
|
||||
def recheck_session(self, request: HttpRequest):
|
||||
"""Check if a session is still valid with a changed IP"""
|
||||
last_ip = request.session.get(request.session.model.Keys.LAST_IP)
|
||||
last_ip = request.session.get(Session.Keys.LAST_IP)
|
||||
new_ip = ClientIPMiddleware.get_client_ip(request)
|
||||
# Check changed IP
|
||||
if new_ip == last_ip:
|
||||
@@ -109,7 +110,7 @@ class BoundSessionMiddleware(SessionMiddleware):
|
||||
if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
|
||||
# Only set the last IP in the session if there's a binding specified
|
||||
# (== basically requires the user to be logged in)
|
||||
request.session[request.session.model.Keys.LAST_IP] = new_ip
|
||||
request.session[Session.Keys.LAST_IP] = new_ip
|
||||
|
||||
@staticmethod
|
||||
def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.1.11 on 2025-06-18 16:21
|
||||
|
||||
import authentik.lib.utils.time
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_stages_user_login", "0006_userloginstage_geoip_binding_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="userloginstage",
|
||||
name="remember_device",
|
||||
field=models.TextField(
|
||||
default="days=30",
|
||||
help_text="When set to a non-zero value, authentik will save a cookie with a longer expiry,to remember the device the user is logging in from. (Format: hours=-1;minutes=-2;seconds=-3)",
|
||||
validators=[authentik.lib.utils.time.timedelta_string_validator],
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -63,6 +63,15 @@ class UserLoginStage(Stage):
|
||||
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||
),
|
||||
)
|
||||
remember_device = models.TextField(
|
||||
default="days=30",
|
||||
validators=[timedelta_string_validator],
|
||||
help_text=_(
|
||||
"When set to a non-zero value, authentik will save a cookie with a longer expiry,"
|
||||
"to remember the device the user is logging in from. "
|
||||
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""Login stage logic"""
|
||||
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
from hashlib import sha256
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth import login
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.translation import gettext as _
|
||||
from jwt import PyJWTError, decode, encode
|
||||
from rest_framework.fields import BooleanField, CharField
|
||||
|
||||
from authentik.core.models import Session, User
|
||||
from authentik.core.models import AuthenticatedSession, Session, User
|
||||
from authentik.events.middleware import audit_ignore
|
||||
from authentik.flows.challenge import ChallengeResponse, WithUserInfoChallenge
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||
@@ -16,12 +19,20 @@ from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.password import BACKEND_INBUILT
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
|
||||
from authentik.stages.password.stage import (
|
||||
PLAN_CONTEXT_AUTHENTICATION_BACKEND,
|
||||
PLAN_CONTEXT_METHOD_ARGS,
|
||||
)
|
||||
from authentik.stages.user_login.middleware import (
|
||||
SESSION_KEY_BINDING_GEO,
|
||||
SESSION_KEY_BINDING_NET,
|
||||
)
|
||||
from authentik.stages.user_login.models import UserLoginStage
|
||||
from authentik.tenants.utils import get_unique_identifier
|
||||
|
||||
COOKIE_NAME_KNOWN_DEVICE = "authentik_device"
|
||||
|
||||
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE = "known_device"
|
||||
|
||||
|
||||
class UserLoginChallenge(WithUserInfoChallenge):
|
||||
@@ -78,12 +89,63 @@ class UserLoginStageView(ChallengeStageView):
|
||||
self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding
|
||||
self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding
|
||||
|
||||
def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse:
|
||||
"""Attach the currently pending user to the current session"""
|
||||
# FIXME: identical function in authenticator_validate
|
||||
@property
|
||||
def cookie_jwt_key(self) -> str:
|
||||
"""Signing key for Known-device Cookie for this stage"""
|
||||
return sha256(
|
||||
f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii")
|
||||
).hexdigest()
|
||||
|
||||
def set_known_device_cookie(self, user: User):
|
||||
"""Set a cookie, valid longer than the session, which denotes that this user
|
||||
has logged in on this device before."""
|
||||
delta = timedelta_from_string(self.executor.current_stage.remember_device)
|
||||
response = self.executor.stage_ok()
|
||||
if delta.total_seconds() < 1:
|
||||
return response
|
||||
expiry = datetime.now() + delta
|
||||
cookie_payload = {
|
||||
"sub": user.uid,
|
||||
"exp": expiry.timestamp(),
|
||||
}
|
||||
cookie = encode(cookie_payload, self.cookie_jwt_key)
|
||||
response.set_cookie(
|
||||
COOKIE_NAME_KNOWN_DEVICE,
|
||||
cookie,
|
||||
expires=expiry,
|
||||
path=settings.SESSION_COOKIE_PATH,
|
||||
domain=settings.SESSION_COOKIE_DOMAIN,
|
||||
samesite=settings.SESSION_COOKIE_SAMESITE,
|
||||
)
|
||||
return response
|
||||
|
||||
def is_known_device(self, user: User):
|
||||
"""Returns `True` if the login happened on a "known" device, by the same user."""
|
||||
client_ip = ClientIPMiddleware.get_client_ip(self.request)
|
||||
if AuthenticatedSession.objects.filter(session__last_ip=client_ip).exists():
|
||||
return True
|
||||
if COOKIE_NAME_KNOWN_DEVICE not in self.request.COOKIES:
|
||||
return False
|
||||
try:
|
||||
payload = decode(
|
||||
self.request.COOKIES[COOKIE_NAME_KNOWN_DEVICE], self.cookie_jwt_key, ["HS256"]
|
||||
)
|
||||
if payload["sub"] == user.uid:
|
||||
return True
|
||||
return False
|
||||
except (PyJWTError, ValueError, TypeError) as exc:
|
||||
self.logger.info("eh", exc=exc)
|
||||
return False
|
||||
|
||||
def do_login(self, request: HttpRequest, remember: bool | None = None) -> HttpResponse:
|
||||
"""Attach the currently pending user to the current session.
|
||||
`remember` Argument should be `None` if not configured, otherwise set to `True`/`False`
|
||||
representative of the user's choice."""
|
||||
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
|
||||
message = _("No Pending user to login.")
|
||||
messages.error(request, message)
|
||||
self.logger.debug(message)
|
||||
self.logger.warning(message)
|
||||
return self.executor.stage_invalid()
|
||||
backend = self.executor.plan.context.get(
|
||||
PLAN_CONTEXT_AUTHENTICATION_BACKEND, BACKEND_INBUILT
|
||||
@@ -91,8 +153,13 @@ class UserLoginStageView(ChallengeStageView):
|
||||
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
||||
if not user.is_active:
|
||||
self.logger.warning("User is not active, login will not work.")
|
||||
delta = self.set_session_duration(remember)
|
||||
delta = self.set_session_duration(bool(remember))
|
||||
self.set_session_ip()
|
||||
# Check if the login request is coming from a known device
|
||||
self.executor.plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||
self.executor.plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, self.is_known_device(user)
|
||||
)
|
||||
# the `user_logged_in` signal will update the user to write the `last_login` field
|
||||
# which we don't want to log as we already have a dedicated login event
|
||||
with audit_ignore():
|
||||
@@ -112,4 +179,6 @@ class UserLoginStageView(ChallengeStageView):
|
||||
Session.objects.filter(
|
||||
authenticatedsession__user=user,
|
||||
).exclude(session_key=self.request.session.session_key).delete()
|
||||
if remember is None:
|
||||
return self.set_known_device_cookie(user)
|
||||
return self.executor.stage_ok()
|
||||
|
||||
@@ -8,17 +8,18 @@ from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, Session
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_flow, create_test_user
|
||||
from authentik.flows.markers import StageMarker
|
||||
from authentik.flows.models import FlowDesignation, FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.user_login.middleware import (
|
||||
SESSION_KEY_BINDING_NET,
|
||||
BoundSessionMiddleware,
|
||||
SessionBindingBroken,
|
||||
logout_extra,
|
||||
@@ -31,7 +32,7 @@ class TestUserLoginStage(FlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.user = create_test_admin_user()
|
||||
self.user = create_test_user()
|
||||
|
||||
self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
|
||||
self.stage = UserLoginStage.objects.create(name="login")
|
||||
@@ -247,3 +248,21 @@ class TestUserLoginStage(FlowTestCase):
|
||||
request.session = self.client.session
|
||||
request.user = self.user
|
||||
logout_extra(request, cm.exception)
|
||||
|
||||
def test_session_binding_broken(self):
|
||||
"""Test session binding"""
|
||||
self.client.force_login(self.user)
|
||||
session = self.client.session
|
||||
session[Session.Keys.LAST_IP] = "192.0.2.1"
|
||||
session[SESSION_KEY_BINDING_NET] = NetworkBinding.BIND_ASN_NETWORK_IP
|
||||
session.save()
|
||||
|
||||
res = self.client.get(reverse("authentik_api:user-me"))
|
||||
self.assertEqual(res.status_code, 302)
|
||||
self.assertEqual(
|
||||
res.url,
|
||||
reverse(
|
||||
"authentik_flows:default-authentication",
|
||||
)
|
||||
+ f"?{NEXT_ARG_NAME}={reverse("authentik_api:user-me")}",
|
||||
)
|
||||
|
||||
@@ -4754,6 +4754,7 @@
|
||||
"add_token",
|
||||
"change_token",
|
||||
"delete_token",
|
||||
"set_token_key",
|
||||
"view_token",
|
||||
"view_token_key"
|
||||
]
|
||||
@@ -4890,6 +4891,7 @@
|
||||
"authentik_core.preview_user",
|
||||
"authentik_core.remove_user_from_group",
|
||||
"authentik_core.reset_user_password",
|
||||
"authentik_core.set_token_key",
|
||||
"authentik_core.unassign_user_permissions",
|
||||
"authentik_core.view_application",
|
||||
"authentik_core.view_applicationentitlement",
|
||||
@@ -9536,6 +9538,7 @@
|
||||
"authentik_core.preview_user",
|
||||
"authentik_core.remove_user_from_group",
|
||||
"authentik_core.reset_user_password",
|
||||
"authentik_core.set_token_key",
|
||||
"authentik_core.unassign_user_permissions",
|
||||
"authentik_core.view_application",
|
||||
"authentik_core.view_applicationentitlement",
|
||||
@@ -10958,6 +10961,7 @@
|
||||
"enum": [
|
||||
"apple",
|
||||
"openidconnect",
|
||||
"entraid",
|
||||
"azuread",
|
||||
"discord",
|
||||
"facebook",
|
||||
@@ -15546,6 +15550,12 @@
|
||||
],
|
||||
"title": "Geoip binding",
|
||||
"description": "Bind sessions created by this stage to the configured GeoIP location"
|
||||
},
|
||||
"remember_device": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Remember device",
|
||||
"description": "When set to a non-zero value, authentik will save a cookie with a longer expiry,to remember the device the user is logging in from. (Format: hours=-1;minutes=-2;seconds=-3)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
|
||||
17
go.mod
17
go.mod
@@ -6,18 +6,18 @@ require (
|
||||
beryju.io/ldap v0.1.0
|
||||
github.com/avast/retry-go/v4 v4.6.1
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/getsentry/sentry-go v0.34.0
|
||||
github.com/getsentry/sentry-go v0.34.1
|
||||
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
|
||||
github.com/go-ldap/ldap/v3 v3.4.11
|
||||
github.com/go-openapi/runtime v0.28.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/golang-jwt/jwt/v5 v5.2.3
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/handlers v1.5.2
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/securecookie v1.1.2
|
||||
github.com/gorilla/sessions v1.4.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/grafana/pyroscope-go v1.2.2
|
||||
github.com/grafana/pyroscope-go v1.2.3
|
||||
github.com/jellydator/ttlcache/v3 v3.4.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
|
||||
@@ -29,10 +29,10 @@ require (
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/wwt/guac v1.3.2
|
||||
goauthentik.io/api/v3 v3.2025063.1
|
||||
goauthentik.io/api/v3 v3.2025063.5
|
||||
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.15.0
|
||||
golang.org/x/sync v0.16.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
|
||||
)
|
||||
@@ -77,9 +77,10 @@ require (
|
||||
go.opentelemetry.io/otel v1.24.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/net v0.40.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
36
go.sum
36
go.sum
@@ -71,8 +71,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
|
||||
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4=
|
||||
github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
|
||||
github.com/getsentry/sentry-go v0.34.1 h1:HSjc1C/OsnZttohEPrrqKH42Iud0HuLCXpv8cU1pWcw=
|
||||
github.com/getsentry/sentry-go v0.34.1/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
@@ -115,8 +115,8 @@ github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+Gr
|
||||
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
|
||||
github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58=
|
||||
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.3 h1:kkGXqQOBSDDWRhWNXTFpqGSCMyh/PLnqUvMGJPDJDs0=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.3/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
@@ -180,8 +180,8 @@ github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2e
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||
github.com/grafana/pyroscope-go v1.2.3 h1:Rp8mjqqGqmRDvV6XYmuedUAv7wVnQJK/M1pBt6uNwxU=
|
||||
github.com/grafana/pyroscope-go v1.2.3/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||
@@ -298,16 +298,16 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
|
||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
goauthentik.io/api/v3 v3.2025063.1 h1:zvKhZTESgMY/SNiLuTs7G0YleBnev1v7+S9Xd6PZ9bc=
|
||||
goauthentik.io/api/v3 v3.2025063.1/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||
goauthentik.io/api/v3 v3.2025063.5 h1:j5el9/qI/72Q5x5QAiMzgQTswMj2TK3h74OaBcFEtkI=
|
||||
goauthentik.io/api/v3 v3.2025063.5/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -366,8 +366,8 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
|
||||
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -384,8 +384,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -413,15 +413,15 @@ golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
@@ -28,7 +28,7 @@ func (fe *FlowExecutor) solveChallenge_AuthenticatorValidate(challenge *api.Chal
|
||||
var deviceChallenge *api.DeviceChallenge
|
||||
inner := api.NewAuthenticatorValidationChallengeResponseRequest()
|
||||
for _, devCh := range challenge.AuthenticatorValidationChallenge.DeviceChallenges {
|
||||
if devCh.DeviceClass == string(api.DEVICECLASSESENUM_DUO) {
|
||||
if devCh.DeviceClass == api.DEVICECLASSESENUM_DUO {
|
||||
deviceChallenge = &devCh
|
||||
devId, err := strconv.ParseInt(deviceChallenge.DeviceUid, 10, 32)
|
||||
if err != nil {
|
||||
@@ -38,8 +38,8 @@ func (fe *FlowExecutor) solveChallenge_AuthenticatorValidate(challenge *api.Chal
|
||||
inner.SelectedChallenge = (*api.DeviceChallengeRequest)(deviceChallenge)
|
||||
inner.Duo = &devId32
|
||||
}
|
||||
if devCh.DeviceClass == string(api.DEVICECLASSESENUM_STATIC) ||
|
||||
devCh.DeviceClass == string(api.DEVICECLASSESENUM_TOTP) {
|
||||
if devCh.DeviceClass == api.DEVICECLASSESENUM_STATIC ||
|
||||
devCh.DeviceClass == api.DEVICECLASSESENUM_TOTP {
|
||||
// Only use code-based devices if we have a code in the entered password,
|
||||
// and we haven't selected a push device yet
|
||||
if deviceChallenge == nil && fe.getAnswer(StageAuthenticatorValidate) != "" {
|
||||
|
||||
@@ -100,6 +100,9 @@ elif [[ "$1" == "healthcheck" ]]; then
|
||||
elif [[ "$1" == "dump_config" ]]; then
|
||||
shift
|
||||
exec python -m authentik.lib.config $@
|
||||
elif [[ "$1" == "support" ]]; then
|
||||
wait_for_db
|
||||
exec python -m lifecycle.support
|
||||
elif [[ "$1" == "debug" ]]; then
|
||||
exec sleep infinity
|
||||
else
|
||||
|
||||
9
lifecycle/aws/package-lock.json
generated
9
lifecycle/aws/package-lock.json
generated
@@ -9,7 +9,7 @@
|
||||
"version": "0.0.0",
|
||||
"license": "MIT",
|
||||
"devDependencies": {
|
||||
"aws-cdk": "^2.1020.1",
|
||||
"aws-cdk": "^2.1021.0",
|
||||
"cross-env": "^7.0.3"
|
||||
},
|
||||
"engines": {
|
||||
@@ -17,10 +17,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/aws-cdk": {
|
||||
"version": "2.1020.1",
|
||||
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1020.1.tgz",
|
||||
"integrity": "sha512-4UG9qzf6ZSDjINubcukPZChVj6PvDJAHiURAw0jYSkUhObPkX7Zo9uNUIlXzrM+hpB2N2jwRKY9b3sN+KDbtAQ==",
|
||||
"version": "2.1021.0",
|
||||
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1021.0.tgz",
|
||||
"integrity": "sha512-kE557b4N9UFWax+7km3R6D56o4tGhpzOks/lRDugaoC8su3mocLCXJhb954b/IRl0ipnbZnY/Sftq+RQ/sxivg==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"cdk": "bin/cdk"
|
||||
},
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"node": ">=20"
|
||||
},
|
||||
"devDependencies": {
|
||||
"aws-cdk": "^2.1020.1",
|
||||
"aws-cdk": "^2.1021.0",
|
||||
"cross-env": "^7.0.3"
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user