Compare commits

...

15 Commits

Author SHA1 Message Date
Jens Langhammer
c8f57bd361 sigh
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-19 00:20:30 +02:00
Jens Langhammer
4ff7ebc12f fix more
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-19 00:16:21 +02:00
Jens Langhammer
e90b114afc fix id_token
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-19 00:13:42 +02:00
Jens Langhammer
047a0cfa13 fix redirect url actually
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 23:21:13 +02:00
Jens Langhammer
27c1292763 fix proxy
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 23:16:40 +02:00
Jens Langhammer
d4b097b631 fix oauth redirect url
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 22:53:41 +02:00
Jens Langhammer
aab93d09d8 fix a couple more things
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 22:34:47 +02:00
Jens Langhammer
92d82db7b7 fix
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 22:01:04 +02:00
Jens Langhammer
e693fe4937 fix outposts
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 21:53:08 +02:00
Jens L.
c9f05476f3 Potential fix for pull request finding 'Unused local variable'
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Signed-off-by: Jens L. <jens@beryju.org>
2026-04-18 21:45:54 +02:00
Jens Langhammer
335e1cb5f9 remove dacite
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 21:38:41 +02:00
Jens Langhammer
f23d5bd7d7 replace enterprise
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 21:37:24 +02:00
Jens Langhammer
4dd952df3f providers/oauth2: replace
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 21:35:48 +02:00
Jens Langhammer
1c0e373f87 outposts: replace
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 21:32:53 +02:00
Jens Langhammer
6441c1dcf4 blueprints: replace
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2026-04-18 21:17:40 +02:00
47 changed files with 504 additions and 458 deletions

View File

@@ -320,7 +320,7 @@ ci--meta-debug:
node --version || echo "No node installed"
ci-lint-mypy: ci--meta-debug
$(UV) run mypy --strict $(PY_SOURCES)
$(UV) run mypy --show-traceback --strict $(PY_SOURCES)
ci-lint-black: ci--meta-debug
$(UV) run black --check $(PY_SOURCES)

View File

@@ -4,7 +4,6 @@ from glob import glob
from pathlib import Path
import django.contrib.postgres.fields
from dacite.core import from_dict
from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
@@ -33,7 +32,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, db_alias, path: Path):
rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
meta = None
if metadata:
meta = from_dict(BlueprintMetadata, metadata)
meta = BlueprintMetadata.model_validate(metadata)
if meta.labels.get(LABEL_AUTHENTIK_INSTANTIATE, "").lower() == "false":
return
if not instance:

View File

@@ -3,7 +3,6 @@
from collections import OrderedDict
from collections.abc import Generator, Iterable, Mapping
from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum
from functools import reduce
from json import JSONDecodeError, loads
@@ -15,12 +14,14 @@ from uuid import UUID
from deepmerge import always_merger
from django.apps import apps
from django.db.models import Model, Q
from pydantic import BaseModel, ConfigDict, Field
from rest_framework.exceptions import ValidationError
from rest_framework.fields import Field
from rest_framework.fields import Field as DRFField
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from yaml import SafeDumper, SafeLoader, ScalarNode, SequenceNode
from authentik.blueprints.v1.meta.registry import MetaResult
from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.policies.models import PolicyBindingModel
@@ -38,7 +39,7 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]:
data = dict(serializer.data)
for field_name, _field in serializer.fields.items():
_field: Field
_field: DRFField
if field_name not in data:
continue
if _field.read_only:
@@ -48,11 +49,12 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]:
return data
@dataclass
class BlueprintEntryState:
class BlueprintEntryState(BaseModel):
"""State of a single instance"""
instance: Model | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
instance: Model | MetaResult | None = None
class BlueprintEntryDesiredState(Enum):
@@ -64,32 +66,35 @@ class BlueprintEntryDesiredState(Enum):
MUST_CREATED = "must_created"
@dataclass
class BlueprintEntryPermission:
class BlueprintEntryPermission(BaseModel):
"""Describe object-level permissions"""
model_config = ConfigDict(arbitrary_types_allowed=True)
permission: str | YAMLTag
user: int | YAMLTag | None = field(default=None)
role: str | YAMLTag | None = field(default=None)
user: int | YAMLTag | None = Field(default=None)
role: str | UUID | YAMLTag | None = Field(default=None)
@dataclass
class BlueprintEntry:
class BlueprintEntry(BaseModel):
"""Single entry of a blueprint"""
model_config = ConfigDict(arbitrary_types_allowed=True)
model: str | YAMLTag
state: BlueprintEntryDesiredState | YAMLTag = field(default=BlueprintEntryDesiredState.PRESENT)
conditions: list[Any] = field(default_factory=list)
identifiers: dict[str, Any] = field(default_factory=dict)
attrs: dict[str, Any] | None = field(default_factory=dict)
permissions: list[BlueprintEntryPermission] = field(default_factory=list)
state: BlueprintEntryDesiredState | YAMLTag = Field(default=BlueprintEntryDesiredState.PRESENT)
conditions: list[Any] = Field(default_factory=list)
identifiers: dict[str, Any] = Field(default_factory=dict)
attrs: dict[str, Any] | None = Field(default_factory=dict)
permissions: list[BlueprintEntryPermission] = Field(default_factory=list)
id: str | None = None
_state: BlueprintEntryState = field(default_factory=BlueprintEntryState)
_state: BlueprintEntryState
def __post_init__(self, *args, **kwargs) -> None:
def model_post_init(self, __context: Any) -> None:
self.__tag_contexts: list[YAMLTagContext] = []
self._state = BlueprintEntryState()
@staticmethod
def from_model(model: SerializerModel, *extra_identifier_names: str) -> BlueprintEntry:
@@ -178,23 +183,23 @@ class BlueprintEntry:
return all(self.tag_resolver(self.conditions, blueprint))
@dataclass
class BlueprintMetadata:
class BlueprintMetadata(BaseModel):
"""Optional blueprint metadata"""
name: str
labels: dict[str, str] = field(default_factory=dict)
labels: dict[str, str] = Field(default_factory=dict)
@dataclass
class Blueprint:
class Blueprint(BaseModel):
"""Dataclass used for a full export"""
version: int = field(default=1)
entries: list[BlueprintEntry] | dict[str, list[BlueprintEntry]] = field(default_factory=list)
context: dict = field(default_factory=dict)
model_config = ConfigDict(arbitrary_types_allowed=True)
metadata: BlueprintMetadata | None = field(default=None)
version: int = Field(default=1)
entries: list[BlueprintEntry] | dict[str, list[BlueprintEntry]] = Field(default_factory=list)
context: dict = Field(default_factory=dict)
metadata: BlueprintMetadata | None = Field(default=None)
def iter_entries(self) -> Iterable[BlueprintEntry]:
if isinstance(self.entries, dict):
@@ -208,7 +213,7 @@ class YAMLTag:
"""Base class for all YAML Tags"""
def __repr__(self) -> str:
return str(self.resolve(BlueprintEntry(""), Blueprint()))
return str(self.resolve(BlueprintEntry(model=""), Blueprint()))
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
"""Implement yaml tag logic"""
@@ -696,18 +701,8 @@ class BlueprintDumper(SafeDumper):
return True
def represent(self, data) -> None:
if is_dataclass(data):
def factory(items):
final_dict = dict(items)
# Remove internal state variables
final_dict.pop("_state", None)
# Future-proof to only remove the ID if we don't set a value
if "id" in final_dict and final_dict.get("id") is None:
final_dict.pop("id")
return final_dict
data = asdict(data, dict_factory=factory)
if isinstance(data, BaseModel):
data = data.model_dump(mode="json", exclude_none=True)
return super().represent(data)

View File

@@ -4,9 +4,6 @@ from contextlib import contextmanager
from copy import deepcopy
from typing import Any
from dacite.config import Config
from dacite.core import from_dict
from dacite.exceptions import DaciteError
from deepmerge import always_merger
from django.contrib.auth.models import Permission
from django.contrib.contenttypes.models import ContentType
@@ -16,6 +13,7 @@ from django.db.models.query_utils import Q
from django.db.transaction import atomic
from django.db.utils import IntegrityError
from guardian.models import RoleObjectPermission
from pydantic import ValidationError as PydanticValidationError
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer, Serializer
from structlog.stdlib import BoundLogger, get_logger
@@ -158,10 +156,8 @@ class Importer:
"""Parse YAML string and create blueprint importer from it"""
import_dict = load(yaml_input, BlueprintLoader)
try:
_import = from_dict(
Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState])
)
except DaciteError as exc:
_import = Blueprint.model_validate(import_dict)
except PydanticValidationError as exc:
raise EntryInvalidError from exc
return Importer(_import, context)
@@ -399,7 +395,7 @@ class Importer:
self.logger.debug("Updated model", model=instance)
if "pk" in entry.identifiers:
self.__pk_map[entry.identifiers["pk"]] = instance.pk
entry._state = BlueprintEntryState(instance)
entry._state = BlueprintEntryState(instance=instance)
self._apply_permissions(instance, entry)
elif state == BlueprintEntryDesiredState.ABSENT:
instance: Model | None = serializer.instance

View File

@@ -1,12 +1,10 @@
"""v1 blueprints tasks"""
from dataclasses import asdict, dataclass, field
from hashlib import sha512
from pathlib import Path
from sys import platform
from uuid import UUID
from dacite.core import from_dict
from django.conf import settings
from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.text import slugify
@@ -14,6 +12,7 @@ from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from dramatiq.middleware import Middleware
from pydantic import BaseModel, Field
from structlog.stdlib import get_logger
from watchdog.events import (
FileCreatedEvent,
@@ -45,15 +44,14 @@ from authentik.tenants.models import Tenant
LOGGER = get_logger()
@dataclass
class BlueprintFile:
class BlueprintFile(BaseModel):
"""Basic info about a blueprint file"""
path: str
version: int
hash: str
last_m: int
meta: BlueprintMetadata | None = field(default=None)
meta: BlueprintMetadata | None = Field(default=None)
class BlueprintWatcherMiddleware(Middleware):
@@ -115,7 +113,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
def blueprints_find_dict():
blueprints = []
for blueprint in blueprints_find():
blueprints.append(sanitize_dict(asdict(blueprint)))
blueprints.append(sanitize_dict(blueprint.model_dump(mode="json")))
return blueprints
@@ -142,8 +140,10 @@ def blueprints_find() -> list[BlueprintFile]:
LOGGER.warning("invalid blueprint version", version=version, path=str(rel_path))
continue
file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(str(rel_path), version, file_hash, int(path.stat().st_mtime))
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprint = BlueprintFile(
path=str(rel_path), version=version, hash=file_hash, last_m=int(path.stat().st_mtime)
)
blueprint.meta = BlueprintMetadata.model_validate(metadata) if metadata else None
blueprints.append(blueprint)
return blueprints
@@ -205,7 +205,7 @@ def apply_blueprint(instance_pk: UUID):
file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer.from_string(blueprint_content, instance.context)
if importer.blueprint.metadata:
instance.metadata = asdict(importer.blueprint.metadata)
instance.metadata = importer.blueprint.metadata.model_dump(mode="json")
valid, logs = importer.validate()
if not valid:
instance.status = BlueprintInstanceStatus.ERROR

View File

@@ -22,7 +22,11 @@ class TestApplicationsAPI(APITestCase):
self.user = create_test_admin_user()
self.provider = OAuth2Provider.objects.create(
name="test",
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://some-other-domain")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://some-other-domain"
)
],
authorization_flow=create_test_flow(),
)
self.allowed: Application = Application.objects.create(

View File

@@ -291,7 +291,9 @@ class TestCrypto(APITestCase):
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=keypair,
)
response = self.client.get(
@@ -323,7 +325,9 @@ class TestCrypto(APITestCase):
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=keypair,
)
response = self.client.get(

View File

@@ -2,7 +2,6 @@
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import lru_cache
@@ -10,12 +9,13 @@ from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import DaciteError, from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from jwt.algorithms import ECAlgorithm
from pydantic import BaseModel, Field
from pydantic import ValidationError as PydanticValidationError
from rest_framework.exceptions import ValidationError
from rest_framework.fields import (
ChoiceField,
@@ -60,8 +60,7 @@ class LicenseFlags(Enum):
NON_PRODUCTION = "non_production"
@dataclass
class LicenseSummary:
class LicenseSummary(BaseModel):
"""Internal representation of a license summary"""
internal_users: int
@@ -81,8 +80,7 @@ class LicenseSummarySerializer(PassiveSerializer):
license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags)))
@dataclass
class LicenseKey:
class LicenseKey(BaseModel):
"""License JWT claims"""
aud: str
@@ -91,7 +89,7 @@ class LicenseKey:
name: str
internal_users: int = 0
external_users: int = 0
license_flags: list[LicenseFlags] = field(default_factory=list)
license_flags: list[LicenseFlags] = Field(default_factory=list)
@staticmethod
def validate(jwt: str, check_expiry=True) -> LicenseKey:
@@ -118,8 +116,7 @@ class LicenseKey:
# authentik will change its license generation to `algorithm="ES384"` in 2026.
# TODO: remove this when the last incompatible license runs out.
ECAlgorithm._validate_curve = lambda *_: True
body = from_dict(
LicenseKey,
body = LicenseKey.model_validate(
decode(
jwt,
our_cert.public_key(),
@@ -140,7 +137,13 @@ class LicenseKey:
@staticmethod
def get_total() -> LicenseKey:
"""Get a summarized version of all (not expired) licenses"""
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
total = LicenseKey(
aud=get_license_aud(),
exp=0,
name="Summarized license",
internal_users=0,
external_users=0,
)
for lic in License.objects.all():
if lic.is_valid:
total.internal_users += lic.internal_users
@@ -219,7 +222,7 @@ class LicenseKey:
external_user_count=self.get_external_user_count(),
status=self.status(),
)
summary = asdict(self.summary())
summary = self.summary().model_dump(mode="json")
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return usage
@@ -243,7 +246,7 @@ class LicenseKey:
if not summary:
return LicenseKey.get_total().summary()
try:
return from_dict(LicenseSummary, summary)
except DaciteError:
return LicenseSummary.model_validate(summary)
except PydanticValidationError:
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
return LicenseKey.get_total().summary()

View File

@@ -1,6 +1,3 @@
import json
from dataclasses import asdict
from django.urls import reverse
from django.utils import timezone
from rest_framework.test import APITestCase
@@ -80,11 +77,7 @@ class TestSSFAuth(APITestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
res = self.client.post(

View File

@@ -1,10 +1,9 @@
"""Outpost API Views"""
from dacite.core import from_dict
from dacite.exceptions import DaciteError
from django_filters.filters import ModelMultipleChoiceFilter
from django_filters.filterset import FilterSet
from drf_spectacular.utils import extend_schema
from pydantic import ValidationError as PydanticValidationError
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, CharField, DateTimeField, SerializerMethodField
@@ -90,9 +89,9 @@ class OutpostSerializer(ModelSerializer):
def validate_config(self, config) -> dict:
"""Check that the config has all required fields"""
try:
parsed = from_dict(OutpostConfig, config)
parsed = OutpostConfig.model_validate(config)
timedelta_string_validator(parsed.refresh_interval)
except DaciteError as exc:
except PydanticValidationError as exc:
raise ValidationError(f"Failed to validate config: {str(exc)}") from exc
return config

View File

@@ -1,7 +1,5 @@
"""Outpost API Views"""
from dataclasses import asdict
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema
from kubernetes.client.configuration import Configuration
@@ -79,8 +77,8 @@ class ServiceConnectionViewSet(
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
def state(self, request: Request, pk: str) -> Response:
"""Get the service connection's state"""
connection = self.get_object()
return Response(asdict(connection.state))
connection: OutpostServiceConnection = self.get_object()
return Response(connection.state.model_dump(mode="json"))
class DockerServiceConnectionSerializer(ServiceConnectionSerializer):

View File

@@ -1,6 +1,5 @@
"""Outpost websocket handler"""
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import IntEnum
from hashlib import sha256
@@ -10,11 +9,10 @@ from uuid import UUID
from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer
from dacite.core import from_dict
from dacite.data import Data
from django.db import connection
from django.http.request import QueryDict
from guardian.shortcuts import get_objects_for_user
from pydantic import BaseModel, Field
from structlog.stdlib import BoundLogger, get_logger
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
@@ -50,12 +48,11 @@ class WebsocketMessageInstruction(IntEnum):
SESSION_END = 4
@dataclass(slots=True)
class WebsocketMessage:
class WebsocketMessage(BaseModel):
"""Complete Websocket Message that is being sent"""
instruction: int
args: dict[str, Any] = field(default_factory=dict)
args: dict[str, Any] = Field(default_factory=dict)
class OutpostConsumer(JsonWebsocketConsumer):
@@ -118,8 +115,8 @@ class OutpostConsumer(JsonWebsocketConsumer):
expected=self.outpost.config.kubernetes_replicas,
).dec()
def receive_json(self, content: Data, **kwargs):
msg = from_dict(WebsocketMessage, content)
def receive_json(self, content: Any, **kwargs):
msg = WebsocketMessage.model_validate(content)
if not self.outpost:
raise DenyConnection()
@@ -146,29 +143,29 @@ class OutpostConsumer(JsonWebsocketConsumer):
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
self.send_json(asdict(response))
self.send_json(response.model_dump(mode="json"))
def event_update(self, event): # pragma: no cover
"""Event handler which is called by post_save signals, Send update instruction"""
self.send_json(
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE).model_dump(
mode="json"
)
)
def event_session_end(self, event):
"""Event handler which is called when a session is ended"""
self.send_json(
asdict(
WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
)
WebsocketMessage(
instruction=WebsocketMessageInstruction.SESSION_END, args=event
).model_dump(mode="json")
)
def event_provider_specific(self, event):
"""Event handler which can be called by provider-specific
implementations to send specific messages to the outpost"""
self.send_json(
asdict(
WebsocketMessage(
instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
)
)
WebsocketMessage(
instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
).model_dump(mode="json")
)

View File

@@ -2,12 +2,10 @@
import re
import ssl
from dataclasses import asdict
from json import dumps
from typing import TYPE_CHECKING, TypeVar
import urllib3
from dacite.core import from_dict
from django.http import HttpResponseNotFound
from django.utils.text import slugify
from jsonpatch import JsonPatchConflict, JsonPatchException, JsonPatchTestFailed, apply_patch
@@ -15,6 +13,7 @@ from kubernetes.client import ApiClient, V1ObjectMeta
from kubernetes.client.exceptions import ApiException, OpenApiException
from kubernetes.client.models.v1_deployment import V1Deployment
from kubernetes.client.models.v1_pod import V1Pod
from pydantic import BaseModel
from requests import Response
from structlog.stdlib import get_logger
from urllib3.exceptions import HTTPError
@@ -97,11 +96,10 @@ class KubernetesObjectReconciler[T]:
"""Get patched reference object"""
reference = self.get_reference_object()
patch = self.get_patch()
try:
if isinstance(reference, BaseModel):
json = reference.model_dump(mode="json")
else:
json = self.api_client.sanitize_for_serialization(reference)
# Custom objects will not be known to the clients openapi types
except AttributeError:
json = asdict(reference)
try:
ref = json
if patch is not None:
@@ -111,12 +109,10 @@ class KubernetesObjectReconciler[T]:
mock_response = Response()
mock_response.data = dumps(ref)
try:
if isinstance(reference, BaseModel):
result = reference.__class__.model_validate(ref)
else:
result = self.api_client.deserialize(mock_response, reference.__class__.__name__)
# Custom objects will not be known to the clients openapi types
except AttributeError:
result = from_dict(reference.__class__, data=ref)
return result
def up(self):
@@ -191,10 +187,10 @@ class KubernetesObjectReconciler[T]:
patch = self.get_patch()
if patch is not None:
try:
if isinstance(reference, BaseModel):
current_json = reference.model_dump(mode="json")
else:
current_json = self.api_client.sanitize_for_serialization(current)
except AttributeError:
current_json = asdict(current)
try:
if apply_patch(current_json, patch) != current_json:
raise NeedsUpdate()

View File

@@ -1,10 +1,9 @@
"""Kubernetes Prometheus ServiceMonitor Reconciler"""
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING
from dacite.core import from_dict
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi
from pydantic import BaseModel, Field
from authentik.outposts.controllers.base import FIELD_MANAGER
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler
@@ -15,23 +14,20 @@ if TYPE_CHECKING:
from authentik.outposts.controllers.kubernetes import KubernetesController
@dataclass(slots=True)
class PrometheusServiceMonitorSpecEndpoint:
class PrometheusServiceMonitorSpecEndpoint(BaseModel):
"""Prometheus ServiceMonitor endpoint spec"""
port: str
path: str = field(default="/metrics")
path: str = Field(default="/metrics")
@dataclass(slots=True)
class PrometheusServiceMonitorSpecSelector:
class PrometheusServiceMonitorSpecSelector(BaseModel):
"""Prometheus ServiceMonitor selector spec"""
matchLabels: dict
@dataclass(slots=True)
class PrometheusServiceMonitorSpec:
class PrometheusServiceMonitorSpec(BaseModel):
"""Prometheus ServiceMonitor spec"""
endpoints: list[PrometheusServiceMonitorSpecEndpoint]
@@ -39,17 +35,15 @@ class PrometheusServiceMonitorSpec:
selector: PrometheusServiceMonitorSpecSelector
@dataclass(slots=True)
class PrometheusServiceMonitorMetadata:
class PrometheusServiceMonitorMetadata(BaseModel):
"""Prometheus ServiceMonitor metadata"""
name: str
namespace: str
labels: dict = field(default_factory=dict)
labels: dict = Field(default_factory=dict)
@dataclass(slots=True)
class PrometheusServiceMonitor:
class PrometheusServiceMonitor(BaseModel):
"""Prometheus ServiceMonitor"""
apiVersion: str
@@ -59,7 +53,7 @@ class PrometheusServiceMonitor:
def to_dict(self):
"""`to_dict` to conform to how the kubernetes client converts objects to dicts"""
return asdict(self)
return self.model_dump(mode="json")
CRD_NAME = "servicemonitors.monitoring.coreos.com"
@@ -132,7 +126,7 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
version=CRD_VERSION,
plural=CRD_PLURAL,
namespace=self.namespace,
body=asdict(reference),
body=reference.model_dump(mode="json"),
field_manager=FIELD_MANAGER,
)
@@ -146,15 +140,14 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
)
def retrieve(self) -> PrometheusServiceMonitor:
return from_dict(
PrometheusServiceMonitor,
return PrometheusServiceMonitor.model_validate(
self.api.get_namespaced_custom_object(
group=CRD_GROUP,
version=CRD_VERSION,
namespace=self.namespace,
plural=CRD_PLURAL,
name=self.name,
),
)
)
def update(self, current: PrometheusServiceMonitor, reference: PrometheusServiceMonitor):
@@ -164,6 +157,6 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
namespace=self.namespace,
plural=CRD_PLURAL,
name=self.name,
body=asdict(reference),
body=reference.model_dump(mode="json"),
field_manager=FIELD_MANAGER,
)

View File

@@ -1,12 +1,10 @@
"""Outpost models"""
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any
from uuid import uuid4
from dacite.core import from_dict
from django.contrib.auth.models import Permission
from django.core.cache import cache
from django.db import IntegrityError, models, transaction
@@ -14,6 +12,7 @@ from django.db.models.base import Model
from django.utils.translation import gettext_lazy as _
from model_utils.managers import InheritanceManager
from packaging.version import Version, parse
from pydantic import BaseModel, ConfigDict, Field
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
@@ -49,8 +48,7 @@ class ServiceConnectionInvalid(SentryIgnoredException):
"""Exception raised when a Service Connection has invalid parameters"""
@dataclass
class OutpostConfig:
class OutpostConfig(BaseModel):
"""Configuration an outpost uses to configure it self"""
# update website/docs/add-secure-apps/outposts/_config.md
@@ -60,28 +58,28 @@ class OutpostConfig:
authentik_host_browser: str = ""
log_level: str = CONFIG.get("log_level")
object_naming_template: str = field(default="ak-outpost-%(name)s")
object_naming_template: str = Field(default="ak-outpost-%(name)s")
refresh_interval: str = "minutes=5"
container_image: str | None = field(default=None)
container_image: str | None = Field(default=None)
docker_network: str | None = field(default=None)
docker_map_ports: bool = field(default=True)
docker_labels: dict[str, str] | None = field(default=None)
docker_network: str | None = Field(default=None)
docker_map_ports: bool = Field(default=True)
docker_labels: dict[str, str] | None = Field(default=None)
kubernetes_replicas: int = field(default=1)
kubernetes_namespace: str = field(default_factory=get_namespace)
kubernetes_ingress_annotations: dict[str, str] = field(default_factory=dict)
kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls")
kubernetes_ingress_class_name: str | None = field(default=None)
kubernetes_ingress_path_type: str | None = field(default=None)
kubernetes_httproute_annotations: dict[str, str] = field(default_factory=dict)
kubernetes_httproute_parent_refs: list[dict[str, str]] = field(default_factory=list)
kubernetes_service_type: str = field(default="ClusterIP")
kubernetes_disabled_components: list[str] = field(default_factory=list)
kubernetes_image_pull_secrets: list[str] = field(default_factory=list)
kubernetes_json_patches: dict[str, list[dict[str, Any]]] | None = field(default=None)
kubernetes_disable_x509_strict: bool = field(default=False)
kubernetes_replicas: int = Field(default=1)
kubernetes_namespace: str = Field(default_factory=get_namespace)
kubernetes_ingress_annotations: dict[str, str] = Field(default_factory=dict)
kubernetes_ingress_secret_name: str = Field(default="authentik-outpost-tls")
kubernetes_ingress_class_name: str | None = Field(default=None)
kubernetes_ingress_path_type: str | None = Field(default=None)
kubernetes_httproute_annotations: dict[str, str] = Field(default_factory=dict)
kubernetes_httproute_parent_refs: list[dict[str, str]] = Field(default_factory=list)
kubernetes_service_type: str = Field(default="ClusterIP")
kubernetes_disabled_components: list[str] = Field(default_factory=list)
kubernetes_image_pull_secrets: list[str] = Field(default_factory=list)
kubernetes_json_patches: dict[str, list[dict[str, Any]]] | None = Field(default=None)
kubernetes_disable_x509_strict: bool = Field(default=False)
class OutpostModel(Model):
@@ -104,13 +102,12 @@ class OutpostType(models.TextChoices):
RAC = "rac"
def default_outpost_config(host: str | None = None):
def default_outpost_config(host: str | None = None) -> dict[str, Any]:
"""Get default outpost config"""
return asdict(OutpostConfig(authentik_host=host or ""))
return OutpostConfig(authentik_host=host or "").model_dump(mode="json")
@dataclass
class OutpostServiceConnectionState:
class OutpostServiceConnectionState(BaseModel):
"""State of an Outpost Service Connection"""
version: str
@@ -152,7 +149,7 @@ class OutpostServiceConnection(ScheduledModel, models.Model):
state = cache.get(self.state_key, None)
if not state:
outpost_service_connection_monitor.send_with_options(args=(self.pk,), rel_obj=self)
return OutpostServiceConnectionState("", False)
return OutpostServiceConnectionState(version="", healthy=False)
return state
@property
@@ -292,12 +289,12 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
@property
def config(self) -> OutpostConfig:
"""Load config as OutpostConfig object"""
return from_dict(OutpostConfig, self._config)
return OutpostConfig.model_validate(self._config)
@config.setter
def config(self, value):
def config(self, value: OutpostConfig):
"""Dump config into json"""
self._config = asdict(value)
self._config = value.model_dump(mode="json")
@property
def state_cache_prefix(self) -> str:
@@ -457,23 +454,27 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
verbose_name_plural = _("Outposts")
@dataclass
class OutpostState:
class OutpostState(BaseModel):
"""Outpost instance state, last_seen and version"""
uid: str
last_seen: datetime | None = field(default=None)
version: str | None = field(default=None)
version_should: Version = field(default=OUR_VERSION)
build_hash: str = field(default="")
golang_version: str = field(default="")
openssl_enabled: bool = field(default=False)
openssl_version: str = field(default="")
fips_enabled: bool = field(default=False)
hostname: str = field(default="")
args: dict = field(default_factory=dict)
model_config = ConfigDict(arbitrary_types_allowed=True)
_outpost: Outpost | None = field(default=None)
uid: str
last_seen: datetime | None = Field(default=None)
version: str | None = Field(default=None)
version_should: Version = Field(default=OUR_VERSION)
build_hash: str = Field(default="")
golang_version: str = Field(default="")
openssl_enabled: bool = Field(default=False)
openssl_version: str = Field(default="")
fips_enabled: bool = Field(default=False)
hostname: str = Field(default="")
args: dict = Field(default_factory=dict)
_outpost: Outpost | None
def model_post_init(self, context):
self._outpost = None
@property
def version_outdated(self) -> bool:
@@ -505,7 +506,7 @@ class OutpostState:
if isinstance(data, str):
cache.delete(key)
data = default_data
state = from_dict(OutpostState, data)
state = OutpostState.model_validate(data)
state._outpost = outpost
return state
@@ -513,7 +514,7 @@ class OutpostState:
def save(self, timeout=OUTPOST_HELLO_INTERVAL):
"""Save current state to cache"""
full_key = f"{self._outpost.state_cache_prefix}/{self.uid}"
return cache.set(full_key, asdict(self), timeout=timeout)
return cache.set(full_key, self.model_dump(), timeout=timeout)
def delete(self):
"""Manually delete from cache, used on channel disconnect"""

View File

@@ -1,11 +1,11 @@
"""id_token utils"""
from dataclasses import asdict, dataclass, field
from hashlib import sha256
from typing import TYPE_CHECKING, Any
from django.http import HttpRequest
from django.utils import timezone
from pydantic import BaseModel, Field
from authentik.common.oauth.constants import (
ACR_AUTHENTIK_DEFAULT,
@@ -29,8 +29,7 @@ def hash_session_key(session_key: str) -> str:
return sha256(session_key.encode("ascii")).hexdigest()
@dataclass(slots=True)
class IDToken:
class IDToken(BaseModel):
"""The primary extension that OpenID Connect makes to OAuth 2.0 to enable End-Users to be
Authenticated is the ID Token data structure. The ID Token is a security token that contains
Claims about the Authentication of an End-User by an Authorization Server when using a Client,
@@ -71,14 +70,14 @@ class IDToken:
# JWT ID, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.7
jti: str | None = None
claims: dict[str, Any] = field(default_factory=dict)
claims: dict[str, Any] = Field(default_factory=dict)
@staticmethod
def new(
provider: OAuth2Provider, token: BaseGrantModel, request: HttpRequest, **kwargs
) -> IDToken:
"""Create ID Token"""
id_token = IDToken(provider, token, **kwargs)
id_token = IDToken()
id_token.exp = int(
(token.expires if token.expires is not None else default_token_duration()).timestamp()
)
@@ -140,7 +139,7 @@ class IDToken:
def to_dict(self) -> dict[str, Any]:
"""Convert dataclass to dict, and update with keys from `claims`"""
id_dict = asdict(self)
id_dict = self.model_dump(mode="json")
# All items without a value should be removed instead being set to None/null
# https://openid.net/specs/openid-connect-core-1_0.html#JSONSerialization
for key in list(id_dict.keys()):

View File

@@ -19,7 +19,7 @@ def migrate_redirect_uris(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
mode = RedirectURIMatchingMode.STRICT
if old == "*" or old == ".*":
mode = RedirectURIMatchingMode.REGEX
uris.append(asdict(RedirectURI(mode, url=old)))
uris.append(asdict(RedirectURI(matching_mode=mode, url=old)))
provider._redirect_uris = uris
provider.save()

View File

@@ -3,7 +3,6 @@
import base64
import binascii
import json
from dataclasses import asdict, dataclass
from functools import cached_property
from hashlib import sha256
from typing import TYPE_CHECKING, Any
@@ -17,8 +16,6 @@ from cryptography.hazmat.primitives.asymmetric.ec import (
)
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from dacite import Config
from dacite.core import from_dict
from django.contrib.postgres.indexes import HashIndex
from django.db import models
from django.http import HttpRequest
@@ -29,6 +26,7 @@ from jwcrypto.common import json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jwk import JWK
from jwt import encode
from pydantic import BaseModel
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
@@ -109,8 +107,7 @@ class OAuth2LogoutMethod(models.TextChoices):
FRONTCHANNEL = "frontchannel", _("Front-channel")
@dataclass
class RedirectURI:
class RedirectURI(BaseModel):
"""A single redirect URI entry"""
matching_mode: RedirectURIMatchingMode
@@ -345,16 +342,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
uris = []
for entry in self._redirect_uris:
uris.append(
from_dict(
RedirectURI,
entry,
config=Config(
type_hooks={
RedirectURIMatchingMode: RedirectURIMatchingMode,
RedirectURIType: RedirectURIType,
}
),
)
RedirectURI.model_validate(entry),
)
return uris
@@ -362,7 +350,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
def redirect_uris(self, value: list[RedirectURI]):
cleansed = []
for entry in value:
cleansed.append(asdict(entry))
cleansed.append(entry.model_dump(mode="json"))
self._redirect_uris = cleansed
@property
@@ -550,12 +538,12 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr
from authentik.providers.oauth2.id_token import IDToken
raw_token = json.loads(self._id_token)
return from_dict(IDToken, raw_token)
return IDToken.model_validate(raw_token)
@id_token.setter
def id_token(self, value: IDToken):
self.token = value.to_access_token(self.provider, self)
self._id_token = json.dumps(asdict(value))
self._id_token = json.dumps(value.model_dump(mode="json"))
@property
def at_hash(self):
@@ -603,11 +591,11 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG
from authentik.providers.oauth2.id_token import IDToken
raw_token = json.loads(self._id_token)
return from_dict(IDToken, raw_token)
return IDToken.model_validate(raw_token)
@id_token.setter
def id_token(self, value: IDToken):
self._id_token = json.dumps(asdict(value))
self._id_token = json.dumps(value.model_dump(mode="json"))
@property
def serializer(self) -> Serializer:

View File

@@ -27,7 +27,9 @@ class TestAPI(APITestCase):
self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
@@ -58,8 +60,8 @@ class TestAPI(APITestCase):
"""Test launch_url"""
self.provider.redirect_uris = [
RedirectURI(
RedirectURIMatchingMode.REGEX,
"https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/",
matching_mode=RedirectURIMatchingMode.REGEX,
url="https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/",
),
]
self.provider.save()

View File

@@ -47,7 +47,11 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid/Foo"
)
],
)
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
@@ -73,7 +77,11 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid/Foo"
)
],
)
with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get(
@@ -94,7 +102,11 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
@@ -107,7 +119,11 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
@@ -127,7 +143,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="data:localhost")
],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
@@ -159,7 +177,10 @@ class TestAuthorize(OAuthTestCase):
)
OAuthAuthorizationParams.from_request(request)
provider.refresh_from_db()
self.assertEqual(provider.redirect_uris, [RedirectURI(RedirectURIMatchingMode.STRICT, "+")])
self.assertEqual(
provider.redirect_uris,
[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="+")],
)
def test_invalid_redirect_uri_regex(self):
"""test missing/invalid redirect URI"""
@@ -167,7 +188,11 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.REGEX, url="http://local.invalid?"
)
],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
@@ -187,7 +212,7 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.REGEX, url="+")],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get(
@@ -207,7 +232,7 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.REGEX, url=".+")],
)
request = self.factory.get(
"/",
@@ -226,7 +251,11 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid/Foo"
)
],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
@@ -315,7 +344,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -351,7 +382,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=self.keypair,
)
provider.property_mappings.set(
@@ -421,7 +454,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=self.keypair,
encryption_key=self.keypair,
)
@@ -484,7 +519,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=self.keypair,
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -533,7 +570,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id=generate_id(),
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=self.keypair,
)
provider.property_mappings.set(
@@ -590,7 +629,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id=generate_id(),
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
signing_key=self.keypair,
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
@@ -632,7 +673,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
)
request = self.factory.get(
"/",
@@ -654,7 +697,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost")
],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
@@ -685,7 +730,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -715,7 +762,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -754,7 +803,9 @@ class TestAuthorize(OAuthTestCase):
client_id="test",
authorization_flow=flow,
authentication_flow=auth_flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -780,7 +831,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)

View File

@@ -35,7 +35,9 @@ class TestBackChannelLogout(OAuthTestCase):
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver/callback"),
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver/callback"
),
],
signing_key=self.keypair,
)

View File

@@ -30,19 +30,19 @@ class TestEndSessionView(OAuthTestCase):
invalidation_flow=self.invalidation_flow,
redirect_uris=[
RedirectURI(
RedirectURIMatchingMode.STRICT,
"http://testserver/callback",
RedirectURIType.AUTHORIZATION,
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://testserver/callback",
redirect_uri_type=RedirectURIType.AUTHORIZATION,
),
RedirectURI(
RedirectURIMatchingMode.STRICT,
"http://testserver/logout",
RedirectURIType.LOGOUT,
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://testserver/logout",
redirect_uri_type=RedirectURIType.LOGOUT,
),
RedirectURI(
RedirectURIMatchingMode.REGEX,
r"https://.*\.example\.com/logout",
RedirectURIType.LOGOUT,
matching_mode=RedirectURIMatchingMode.REGEX,
url=r"https://.*\.example\.com/logout",
redirect_uri_type=RedirectURIType.LOGOUT,
),
],
)
@@ -229,9 +229,9 @@ class TestEndSessionAPI(OAuthTestCase):
authorization_flow=create_test_flow(),
redirect_uris=[
RedirectURI(
RedirectURIMatchingMode.STRICT,
"http://testserver/callback",
RedirectURIType.AUTHORIZATION,
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://testserver/callback",
redirect_uri_type=RedirectURIType.AUTHORIZATION,
),
],
)

View File

@@ -1,8 +1,6 @@
"""Test introspect view"""
import json
from base64 import b64encode
from dataclasses import asdict
from django.urls import reverse
from django.utils import timezone
@@ -31,7 +29,7 @@ class TesOAuth2Introspection(OAuthTestCase):
self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
)
self.app = Application.objects.create(
@@ -50,11 +48,7 @@ class TesOAuth2Introspection(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken(iss="foo", sub="bar").model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
@@ -82,11 +76,7 @@ class TesOAuth2Introspection(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken(iss="foo", sub="bar").model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
@@ -126,7 +116,7 @@ class TesOAuth2Introspection(OAuthTestCase):
provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
)
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
@@ -137,11 +127,7 @@ class TesOAuth2Introspection(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
@@ -181,11 +167,7 @@ class TesOAuth2Introspection(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
@@ -206,7 +188,7 @@ class TesOAuth2Introspection(OAuthTestCase):
other_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
client_type=ClientTypes.PUBLIC,
)
@@ -220,11 +202,7 @@ class TesOAuth2Introspection(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken(iss="foo", sub="bar").model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),

View File

@@ -49,7 +49,11 @@ class TestJWKS(OAuthTestCase):
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=create_test_cert(),
)
app = Application.objects.create(name="test", slug="test", provider=provider)
@@ -68,7 +72,11 @@ class TestJWKS(OAuthTestCase):
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
)
app = Application.objects.create(name="test", slug="test", provider=provider)
response = self.client.get(
@@ -82,7 +90,11 @@ class TestJWKS(OAuthTestCase):
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
)
app = Application.objects.create(name="test", slug="test", provider=provider)
@@ -99,7 +111,11 @@ class TestJWKS(OAuthTestCase):
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
encryption_key=create_test_cert(PrivateKeyAlg.ECDSA),
)
@@ -122,7 +138,11 @@ class TestJWKS(OAuthTestCase):
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=cert,
)
app = Application.objects.create(name="test", slug="test", provider=provider)

View File

@@ -1,8 +1,6 @@
"""Test revoke view"""
import json
from base64 import b64encode
from dataclasses import asdict
from django.urls import reverse
from django.utils import timezone
@@ -32,7 +30,7 @@ class TesOAuth2Revoke(OAuthTestCase):
self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
)
self.app = Application.objects.create(
@@ -52,11 +50,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-revoke"),
@@ -75,11 +69,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-revoke"),
@@ -134,11 +124,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
auth_public = b64encode(f"{self.provider.client_id}:{generate_id()}".encode()).decode()
res = self.client.post(
@@ -160,11 +146,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
self.client.logout()
self.assertEqual(AccessToken.objects.including_expired().all().count(), 0)
@@ -185,11 +167,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
session.delete()
self.assertEqual(AccessToken.objects.including_expired().all().count(), 0)
@@ -202,11 +180,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
RefreshToken.objects.create(
provider=self.provider,
@@ -214,11 +188,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
DeviceToken.objects.create(
provider=self.provider,
@@ -239,7 +209,7 @@ class TesOAuth2Revoke(OAuthTestCase):
other_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
client_type=ClientTypes.PUBLIC,
)
@@ -253,11 +223,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-revoke"),
@@ -275,7 +241,7 @@ class TesOAuth2Revoke(OAuthTestCase):
other_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
client_type=ClientTypes.PUBLIC,
)
@@ -289,11 +255,7 @@ class TesOAuth2Revoke(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken().model_dump_json(),
)
auth_public = b64encode(f"{self.provider.client_id}:{generate_id()}".encode()).decode()
res = self.client.post(

View File

@@ -49,7 +49,9 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://TestServer")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://TestServer")
],
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
@@ -76,7 +78,9 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
@@ -97,7 +101,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
@@ -139,7 +147,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
)
# Needs to be assigned to an application for iss to be set
@@ -179,7 +191,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
encryption_key=self.keypair,
)
@@ -210,7 +226,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
)
provider.property_mappings.set(
@@ -271,7 +291,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
)
provider.property_mappings.set(
@@ -328,7 +352,9 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=self.keypair,
)
provider.property_mappings.set(
@@ -400,7 +426,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
refresh_token_threshold="hours=1", # nosec
)
@@ -497,7 +527,11 @@ class TestToken(OAuthTestCase):
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://local.invalid"
)
],
signing_key=self.keypair,
include_claims_in_id_token=True,
)

View File

@@ -53,7 +53,9 @@ class TestTokenClientCredentialsJWTProvider(OAuthTestCase):
self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=self.cert,
)
self.provider.jwt_federation_providers.add(self.other_provider)

View File

@@ -66,7 +66,9 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=self.cert,
)
self.provider.jwt_federation_sources.add(self.source)

View File

@@ -39,7 +39,9 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
self.provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=create_test_cert(),
)
self.provider.property_mappings.set(ScopeMapping.objects.all())

View File

@@ -40,7 +40,9 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
self.provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=create_test_cert(),
)
self.provider.property_mappings.set(ScopeMapping.objects.all())

View File

@@ -43,7 +43,9 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
self.provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=create_test_cert(),
)
self.provider.property_mappings.set(ScopeMapping.objects.all())

View File

@@ -35,7 +35,9 @@ class TestTokenDeviceCode(OAuthTestCase):
self.provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="http://testserver")
],
signing_key=create_test_cert(),
)
self.provider.property_mappings.set(ScopeMapping.objects.all())

View File

@@ -35,7 +35,9 @@ class TestTokenPKCE(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -93,7 +95,9 @@ class TestTokenPKCE(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -149,7 +153,9 @@ class TestTokenPKCE(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
@@ -194,7 +200,9 @@ class TestTokenPKCE(OAuthTestCase):
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")],
redirect_uris=[
RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="foo://localhost")
],
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)

View File

@@ -1,8 +1,5 @@
"""Test userinfo view"""
import json
from dataclasses import asdict
from django.urls import reverse
from django.utils import timezone
@@ -32,7 +29,7 @@ class TestUserinfo(OAuthTestCase):
self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "")],
redirect_uris=[RedirectURI(matching_mode=RedirectURIMatchingMode.STRICT, url="")],
signing_key=create_test_cert(),
)
self.provider.property_mappings.set(ScopeMapping.objects.all())
@@ -46,11 +43,7 @@ class TestUserinfo(OAuthTestCase):
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
_id_token=IDToken(iss="foo", sub="bar").model_dump_json(),
)
def test_userinfo_normal(self):

View File

@@ -201,9 +201,9 @@ class OAuthAuthorizationParams:
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
self.provider.redirect_uris = [
RedirectURI(
RedirectURIMatchingMode.STRICT,
self.redirect_uri,
RedirectURIType.AUTHORIZATION,
matching_mode=RedirectURIMatchingMode.STRICT,
url=self.redirect_uri,
redirect_uri_type=RedirectURIType.AUTHORIZATION,
)
]
self.provider.save()

View File

@@ -1,9 +1,8 @@
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from dacite.core import from_dict
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi, V1ObjectMeta
from pydantic import BaseModel, Field
from authentik.outposts.controllers.base import FIELD_MANAGER
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler
@@ -15,14 +14,12 @@ if TYPE_CHECKING:
from authentik.outposts.controllers.kubernetes import KubernetesController
@dataclass(slots=True)
class RouteBackendRef:
class RouteBackendRef(BaseModel):
name: str
port: int
@dataclass(slots=True)
class RouteSpecParentRefs:
class RouteSpecParentRefs(BaseModel):
name: str
sectionName: str | None = None
port: int | None = None
@@ -31,48 +28,41 @@ class RouteSpecParentRefs:
group: str = "gateway.networking.k8s.io"
@dataclass(slots=True)
class HTTPRouteSpecRuleMatchPath:
class HTTPRouteSpecRuleMatchPath(BaseModel):
type: str
value: str
@dataclass(slots=True)
class HTTPRouteSpecRuleMatchHeader:
class HTTPRouteSpecRuleMatchHeader(BaseModel):
name: str
value: str
type: str = "Exact"
@dataclass(slots=True)
class HTTPRouteSpecRuleMatch:
class HTTPRouteSpecRuleMatch(BaseModel):
path: HTTPRouteSpecRuleMatchPath
headers: list[HTTPRouteSpecRuleMatchHeader]
@dataclass(slots=True)
class HTTPRouteSpecRule:
class HTTPRouteSpecRule(BaseModel):
backendRefs: list[RouteBackendRef]
matches: list[HTTPRouteSpecRuleMatch]
@dataclass(slots=True)
class HTTPRouteSpec:
class HTTPRouteSpec(BaseModel):
parentRefs: list[RouteSpecParentRefs]
hostnames: list[str]
rules: list[HTTPRouteSpecRule]
@dataclass(slots=True)
class HTTPRouteMetadata:
class HTTPRouteMetadata(BaseModel):
name: str
namespace: str
annotations: dict = field(default_factory=dict)
labels: dict = field(default_factory=dict)
annotations: dict = Field(default_factory=dict)
labels: dict = Field(default_factory=dict)
@dataclass(slots=True)
class HTTPRoute:
class HTTPRoute(BaseModel):
apiVersion: str
kind: str
metadata: HTTPRouteMetadata
@@ -183,7 +173,7 @@ class HTTPRouteReconciler(KubernetesObjectReconciler):
),
spec=HTTPRouteSpec(
parentRefs=[
from_dict(RouteSpecParentRefs, spec)
RouteSpecParentRefs.model_validate(spec)
for spec in self.controller.outpost.config.kubernetes_httproute_parent_refs
],
hostnames=hostnames,
@@ -197,7 +187,7 @@ class HTTPRouteReconciler(KubernetesObjectReconciler):
version=self.crd_version,
plural=self.crd_plural,
namespace=self.namespace,
body=asdict(reference),
body=reference.model_dump(mode="json"),
field_manager=FIELD_MANAGER,
)
@@ -211,15 +201,14 @@ class HTTPRouteReconciler(KubernetesObjectReconciler):
)
def retrieve(self) -> HTTPRoute:
return from_dict(
HTTPRoute,
return HTTPRoute.model_validate(
self.api.get_namespaced_custom_object(
group=self.crd_group,
version=self.crd_version,
plural=self.crd_plural,
namespace=self.namespace,
name=self.name,
),
)
)
def update(self, current: HTTPRoute, reference: HTTPRoute):
@@ -229,6 +218,6 @@ class HTTPRouteReconciler(KubernetesObjectReconciler):
plural=self.crd_plural,
namespace=self.namespace,
name=self.name,
body=asdict(reference),
body=reference.model_dump(mode="json"),
field_manager=FIELD_MANAGER,
)

View File

@@ -1,10 +1,9 @@
"""Kubernetes Traefik Middleware Reconciler"""
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING
from dacite.core import from_dict
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi
from pydantic import BaseModel, Field
from authentik.outposts.controllers.base import FIELD_MANAGER
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler
@@ -15,39 +14,35 @@ if TYPE_CHECKING:
from authentik.outposts.controllers.kubernetes import KubernetesController
@dataclass(slots=True)
class TraefikMiddlewareSpecForwardAuth:
class TraefikMiddlewareSpecForwardAuth(BaseModel):
"""traefik middleware forwardAuth spec"""
address: str
authResponseHeadersRegex: str = field(default="")
authResponseHeadersRegex: str = Field(default="")
authResponseHeaders: list[str] = field(default_factory=list)
authResponseHeaders: list[str] = Field(default_factory=list)
trustForwardHeader: bool = field(default=True)
trustForwardHeader: bool = Field(default=True)
maxResponseBodySize: int = field(default=1024 * 1024 * 4)
maxResponseBodySize: int = Field(default=1024 * 1024 * 4)
@dataclass(slots=True)
class TraefikMiddlewareSpec:
class TraefikMiddlewareSpec(BaseModel):
"""Traefik middleware spec"""
forwardAuth: TraefikMiddlewareSpecForwardAuth
@dataclass(slots=True)
class TraefikMiddlewareMetadata:
class TraefikMiddlewareMetadata(BaseModel):
"""Traefik Middleware metadata"""
name: str
namespace: str
labels: dict = field(default_factory=dict)
labels: dict = Field(default_factory=dict)
@dataclass(slots=True)
class TraefikMiddleware:
class TraefikMiddleware(BaseModel):
"""Traefik Middleware"""
apiVersion: str
@@ -153,7 +148,7 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
version=self.crd_version,
plural=self.crd_plural,
namespace=self.namespace,
body=asdict(reference),
body=reference.model_dump(mode="json"),
field_manager=FIELD_MANAGER,
)
@@ -167,15 +162,14 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
)
def retrieve(self) -> TraefikMiddleware:
return from_dict(
TraefikMiddleware,
return TraefikMiddleware.model_validate(
self.api.get_namespaced_custom_object(
group=self.crd_group,
version=self.crd_version,
plural=self.crd_plural,
namespace=self.namespace,
name=self.name,
),
)
)
def update(self, current: TraefikMiddleware, reference: TraefikMiddleware):
@@ -185,6 +179,6 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
plural=self.crd_plural,
namespace=self.namespace,
name=self.name,
body=asdict(reference),
body=reference.model_dump(mode="json"),
field_manager=FIELD_MANAGER,
)

View File

@@ -55,10 +55,14 @@ def get_cookie_secret():
def _get_callback_url(uri: str) -> list[RedirectURI]:
return [
RedirectURI(
RedirectURIMatchingMode.STRICT,
urljoin(uri, "outpost.goauthentik.io/callback") + f"?{OUTPOST_CALLBACK_SIGNATURE}=true",
matching_mode=RedirectURIMatchingMode.STRICT,
url=urljoin(uri, "outpost.goauthentik.io/callback")
+ f"?{OUTPOST_CALLBACK_SIGNATURE}=true",
),
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url=uri + f"?{OUTPOST_CALLBACK_SIGNATURE}=true",
),
RedirectURI(RedirectURIMatchingMode.STRICT, uri + f"?{OUTPOST_CALLBACK_SIGNATURE}=true"),
]

View File

@@ -10,7 +10,6 @@ dependencies = [
"cachetools==7.0.5",
"channels==4.3.2",
"cryptography==46.0.7",
"dacite==1.9.2",
"deepmerge==2.0",
"defusedxml==0.7.1",
"django-channels-postgres",

View File

@@ -1,6 +1,5 @@
"""LDAP and Outpost e2e tests"""
from dataclasses import asdict
from time import sleep
from ldap3 import ALL, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, SUBTREE, Connection, Server
@@ -72,7 +71,7 @@ class TestProviderLDAP(ChannelsE2ETestCase):
outpost: Outpost = Outpost.objects.create(
name=generate_id(),
type=OutpostType.LDAP,
_config=asdict(OutpostConfig(log_level="debug")),
_config=OutpostConfig(log_level="debug").model_dump(mode="json"),
)
outpost.providers.add(ldap)

View File

@@ -80,7 +80,10 @@ class TestProviderOAuth2Github(SeleniumTestCase):
client_secret=self.client_secret,
client_type=ClientTypes.CONFIDENTIAL,
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/github")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/github",
)
],
authorization_flow=authorization_flow,
)
@@ -137,7 +140,10 @@ class TestProviderOAuth2Github(SeleniumTestCase):
client_secret=self.client_secret,
client_type=ClientTypes.CONFIDENTIAL,
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/github")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/github",
)
],
authorization_flow=authorization_flow,
)
@@ -210,7 +216,10 @@ class TestProviderOAuth2Github(SeleniumTestCase):
client_secret=self.client_secret,
client_type=ClientTypes.CONFIDENTIAL,
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/github")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/github",
)
],
authorization_flow=authorization_flow,
)

View File

@@ -89,7 +89,11 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
client_id=self.client_id,
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:3000/")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost:3000/"
)
],
authorization_flow=authorization_flow,
)
provider.property_mappings.set(
@@ -140,7 +144,8 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(
RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/generic_oauth"
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/generic_oauth",
)
],
authorization_flow=authorization_flow,
@@ -213,7 +218,8 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(
RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/generic_oauth"
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/generic_oauth",
)
],
authorization_flow=authorization_flow,
@@ -293,7 +299,8 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(
RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/generic_oauth"
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/generic_oauth",
)
],
)
@@ -377,7 +384,8 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(
RedirectURIMatchingMode.STRICT, "http://localhost:3000/login/generic_oauth"
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:3000/login/generic_oauth",
)
],
)

View File

@@ -74,7 +74,11 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
client_id=self.client_id,
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost:9009/"
)
],
authorization_flow=authorization_flow,
)
provider.property_mappings.set(
@@ -124,7 +128,10 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/auth/callback")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:9009/auth/callback",
)
],
authorization_flow=authorization_flow,
)
@@ -236,7 +243,10 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/auth/callback")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:9009/auth/callback",
)
],
)
provider.property_mappings.set(
@@ -340,7 +350,10 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/auth/callback")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:9009/auth/callback",
)
],
)
provider.property_mappings.set(

View File

@@ -75,7 +75,11 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
client_id=self.client_id,
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/")],
redirect_uris=[
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT, url="http://localhost:9009/"
)
],
authorization_flow=authorization_flow,
)
provider.property_mappings.set(
@@ -125,7 +129,10 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/implicit/")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:9009/implicit/",
)
],
authorization_flow=authorization_flow,
)
@@ -197,7 +204,10 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/implicit/")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:9009/implicit/",
)
],
)
provider.property_mappings.set(
@@ -287,7 +297,10 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
client_secret=self.client_secret,
signing_key=create_test_cert(),
redirect_uris=[
RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost:9009/implicit/")
RedirectURI(
matching_mode=RedirectURIMatchingMode.STRICT,
url="http://localhost:9009/implicit/",
)
],
)
provider.property_mappings.set(

View File

@@ -1,7 +1,6 @@
"""Proxy and Outpost e2e tests"""
from base64 import b64encode
from dataclasses import asdict
from json import dumps
from time import sleep
@@ -244,7 +243,9 @@ class TestProviderProxyConnect(ChannelsE2ETestCase):
outpost: Outpost = Outpost.objects.create(
name=generate_id(),
type=OutpostType.PROXY,
_config=asdict(OutpostConfig(authentik_host=self.live_server_url, log_level="debug")),
_config=OutpostConfig(
authentik_host=self.live_server_url, log_level="debug"
).model_dump(mode="json"),
)
outpost.providers.add(proxy)
outpost.build_user_permissions(outpost.user)

View File

@@ -1,6 +1,5 @@
"""Radius e2e tests"""
from dataclasses import asdict
from time import sleep
from pyrad.client import Client, Timeout
@@ -46,7 +45,7 @@ class TestProviderRadius(E2ETestCase):
outpost: Outpost = Outpost.objects.create(
name=generate_id(),
type=OutpostType.RADIUS,
_config=asdict(OutpostConfig(log_level="debug")),
_config=OutpostConfig(log_level="debug").model_dump(mode="json"),
)
outpost.providers.add(radius)

11
uv.lock generated
View File

@@ -211,7 +211,6 @@ dependencies = [
{ name = "cachetools" },
{ name = "channels" },
{ name = "cryptography" },
{ name = "dacite" },
{ name = "deepmerge" },
{ name = "defusedxml" },
{ name = "django" },
@@ -319,7 +318,6 @@ requires-dist = [
{ name = "cachetools", specifier = "==7.0.5" },
{ name = "channels", specifier = "==4.3.2" },
{ name = "cryptography", specifier = "==46.0.7" },
{ name = "dacite", specifier = "==1.9.2" },
{ name = "deepmerge", specifier = "==2.0" },
{ name = "defusedxml", specifier = "==0.7.1" },
{ name = "django", specifier = "==5.2.13" },
@@ -994,15 +992,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/eb/e65a1a359063d019913cbcb95503d86fc415e18221023b4ec92e35e3d097/cwcwidth-0.1.12-cp314-cp314t-win_amd64.whl", hash = "sha256:fdcfb9632310d2c5b9cee4e8dfbffcfe07b6ca4968d3123b6ca618603b608deb", size = 29706, upload-time = "2025-11-01T17:48:52.965Z" },
]
[[package]]
name = "dacite"
version = "1.9.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" },
]
[[package]]
name = "daphne"
version = "4.2.1"