Compare commits

...

20 Commits

Author SHA1 Message Date
Marc 'risson' Schmitt
1588e6d130 Merge branch 'main' into lib-typing 2025-11-03 17:55:00 +01:00
Marc 'risson' Schmitt
fc0366b3f4 Merge branch 'main' into lib-typing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-23 15:19:07 +02:00
Marc 'risson' Schmitt
db849599f5 lint
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 17:00:03 +02:00
Marc 'risson' Schmitt
267f9d9905 fix more tests
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 16:57:26 +02:00
Marc 'risson' Schmitt
2e62d7cb14 start fixing tests
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 16:54:39 +02:00
Marc 'risson' Schmitt
c4adff1b26 more
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 16:27:08 +02:00
Marc 'risson' Schmitt
adcad1350d more
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 16:03:37 +02:00
Marc 'risson' Schmitt
94f64882ab more
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 14:30:37 +02:00
Marc 'risson' Schmitt
e0b592c035 more
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 14:12:27 +02:00
Marc 'risson' Schmitt
02ae7eada7 more
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-22 14:04:36 +02:00
Marc 'risson' Schmitt
c12749e3e9 more
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-21 18:56:12 +02:00
Marc 'risson' Schmitt
171e83b088 lib: typing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-21 15:17:58 +02:00
Marc 'risson' Schmitt
219666c32d Merge branch 'main' into lib-typing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-10-21 15:13:42 +02:00
Marc 'risson' Schmitt
e9ec83fd03 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-25 13:53:03 +02:00
Marc 'risson' Schmitt
123cca34a1 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-25 13:42:02 +02:00
Marc 'risson' Schmitt
3a2559b115 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-25 13:38:55 +02:00
Marc 'risson' Schmitt
a9b50c8c77 lib: typing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-24 18:52:58 +02:00
Marc 'risson' Schmitt
b81d415faf fix paramspec
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-24 18:10:53 +02:00
Marc 'risson' Schmitt
58d5d37953 also update usage of currenttask
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-24 16:52:36 +02:00
Marc 'risson' Schmitt
cc71bb6e74 packages/django-dramatiq-postgres: typing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-09-24 16:35:18 +02:00
52 changed files with 415 additions and 380 deletions

View File

@@ -2,23 +2,27 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from django.apps import apps
from authentik.blueprints.apps import ManagedAppConfig
from authentik.blueprints.models import BlueprintInstance
P = ParamSpec("P")
R = TypeVar("R")
def apply_blueprint(*files: str):
def apply_blueprint(*files: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Apply blueprint before test"""
from authentik.blueprints.v1.importer import Importer
def wrapper_outer(func: Callable):
def wrapper_outer(func: Callable[P, R]) -> Callable[P, R]:
"""Apply blueprint before test"""
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
for file in files:
content = BlueprintInstance(path=file).retrieve()
Importer.from_string(content).apply()

View File

@@ -3,12 +3,11 @@
from types import CodeType
from typing import Any
from django.db.models import Model
from django.http import HttpRequest
from prometheus_client import Histogram
from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import User
from authentik.core.models import PropertyMapping, User
from authentik.events.models import Event, EventAction
from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.policies.types import PolicyRequest
@@ -23,13 +22,13 @@ PROPERTY_MAPPING_TIME = Histogram(
class PropertyMappingEvaluator(BaseEvaluator):
"""Custom Evaluator that adds some different context variables."""
dry_run: bool
model: Model
dry_run: bool | None
model: PropertyMapping
_compiled: CodeType | None = None
def __init__(
self,
model: Model,
model: PropertyMapping,
user: User | None = None,
request: HttpRequest | None = None,
dry_run: bool | None = False,

View File

@@ -1,13 +1,9 @@
"""GoogleWorkspaceProviderGroup API Views"""
from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.users import PartialGroupSerializer
from authentik.core.api.utils import ModelSerializer
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
@@ -16,7 +12,6 @@ class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
group_obj = PartialGroupSerializer(source="group", read_only=True)
class Meta:
model = GoogleWorkspaceProviderGroup
fields = [
"id",
@@ -29,15 +24,7 @@ class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
extra_kwargs = {"attributes": {"read_only": True}}
class GoogleWorkspaceProviderGroupViewSet(
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class GoogleWorkspaceProviderGroupViewSet(OutgoingSyncConnectionViewSet):
"""GoogleWorkspaceProviderGroup Viewset"""
queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group")

View File

@@ -1,16 +1,13 @@
"""Google Provider API Views"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import (
google_workspace_sync,
google_workspace_sync_objects,
)
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderViewSet
class GoogleWorkspaceProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
@@ -44,18 +41,16 @@ class GoogleWorkspaceProviderSerializer(EnterpriseRequiredMixin, ProviderSeriali
extra_kwargs = {}
class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderViewSet):
"""GoogleWorkspaceProvider Viewset"""
queryset = GoogleWorkspaceProvider.objects.all()
serializer_class = GoogleWorkspaceProviderSerializer
filterset_fields = [
"name",
"exclude_users_service_account",
filterset_fields = OutgoingSyncProviderViewSet.filterset_fields + [
"delegated_subject",
]
search_fields = OutgoingSyncProviderViewSet.search_fields + [
"delegated_subject",
"filter_group",
]
search_fields = ["name"]
ordering = ["name"]
sync_task = google_workspace_sync
sync_objects_task = google_workspace_sync_objects

View File

@@ -1,13 +1,9 @@
"""GoogleWorkspaceProviderUser API Views"""
from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.groups import PartialUserSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
@@ -16,7 +12,6 @@ class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
user_obj = PartialUserSerializer(source="user", read_only=True)
class Meta:
model = GoogleWorkspaceProviderUser
fields = [
"id",
@@ -29,15 +24,7 @@ class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
extra_kwargs = {"attributes": {"read_only": True}}
class GoogleWorkspaceProviderUserViewSet(
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class GoogleWorkspaceProviderUserViewSet(OutgoingSyncConnectionViewSet):
"""GoogleWorkspaceProviderUser Viewset"""
queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user")

View File

@@ -12,7 +12,6 @@ from google.oauth2.service_account import Credentials
from rest_framework.serializers import Serializer
from authentik.core.models import (
BackchannelProvider,
Group,
PropertyMapping,
User,
@@ -84,7 +83,7 @@ class GoogleWorkspaceProviderGroup(SerializerModel):
return f"Google Workspace Provider Group {self.group_id} to {self.provider_id}"
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
class GoogleWorkspaceProvider(OutgoingSyncProvider):
"""Sync users from authentik into Google Workspace."""
delegated_subject = models.EmailField()

View File

@@ -1,13 +1,9 @@
"""MicrosoftEntraProviderGroup API Views"""
from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.users import PartialGroupSerializer
from authentik.core.api.utils import ModelSerializer
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
@@ -16,7 +12,6 @@ class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
group_obj = PartialGroupSerializer(source="group", read_only=True)
class Meta:
model = MicrosoftEntraProviderGroup
fields = [
"id",
@@ -29,15 +24,7 @@ class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
extra_kwargs = {"attributes": {"read_only": True}}
class MicrosoftEntraProviderGroupViewSet(
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class MicrosoftEntraProviderGroupViewSet(OutgoingSyncConnectionViewSet):
"""MicrosoftEntraProviderGroup Viewset"""
queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group")

View File

@@ -1,16 +1,13 @@
"""Microsoft Provider API Views"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import (
microsoft_entra_sync,
microsoft_entra_sync_objects,
)
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderViewSet
class MicrosoftEntraProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
@@ -43,17 +40,10 @@ class MicrosoftEntraProviderSerializer(EnterpriseRequiredMixin, ProviderSerializ
extra_kwargs = {}
class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
class MicrosoftEntraProviderViewSet(OutgoingSyncProviderViewSet):
"""MicrosoftEntraProvider Viewset"""
queryset = MicrosoftEntraProvider.objects.all()
serializer_class = MicrosoftEntraProviderSerializer
filterset_fields = [
"name",
"exclude_users_service_account",
"filter_group",
]
search_fields = ["name"]
ordering = ["name"]
sync_task = microsoft_entra_sync
sync_objects_task = microsoft_entra_sync_objects

View File

@@ -1,13 +1,9 @@
"""MicrosoftEntraProviderUser API Views"""
from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.groups import PartialUserSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
class MicrosoftEntraProviderUserSerializer(ModelSerializer):
@@ -16,7 +12,6 @@ class MicrosoftEntraProviderUserSerializer(ModelSerializer):
user_obj = PartialUserSerializer(source="user", read_only=True)
class Meta:
model = MicrosoftEntraProviderUser
fields = [
"id",
@@ -29,15 +24,7 @@ class MicrosoftEntraProviderUserSerializer(ModelSerializer):
extra_kwargs = {"attributes": {"read_only": True}}
class MicrosoftEntraProviderUserViewSet(
OutgoingSyncConnectionCreateMixin,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class MicrosoftEntraProviderUserViewSet(OutgoingSyncConnectionViewSet):
"""MicrosoftEntraProviderUser Viewset"""
queryset = MicrosoftEntraProviderUser.objects.all().select_related("user")

View File

@@ -12,7 +12,6 @@ from dramatiq.actor import Actor
from rest_framework.serializers import Serializer
from authentik.core.models import (
BackchannelProvider,
Group,
PropertyMapping,
User,
@@ -75,7 +74,7 @@ class MicrosoftEntraProviderGroup(SerializerModel):
return f"Microsoft Entra Provider Group {self.group_id} to {self.provider_id}"
class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
class MicrosoftEntraProvider(OutgoingSyncProvider):
"""Sync users from authentik into Microsoft Entra."""
client_id = models.TextField()

View File

@@ -237,7 +237,7 @@ class Event(SerializerModel, ExpiringModel):
self.save()
return self
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
if self._state.adding:
LOGGER.info(
"Created Event",

View File

@@ -3,7 +3,7 @@
from base64 import b64encode
from functools import cache as funccache
from hashlib import md5, sha256
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from urllib.parse import urlencode
from django.core.cache import cache
@@ -27,7 +27,7 @@ CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available"
GRAVATAR_STATUS_TTL_SECONDS = 60 * 60 * 8 # 8 Hours
SVG_XML_NS = "http://www.w3.org/2000/svg"
SVG_NS_MAP = {None: SVG_XML_NS}
SVG_NS_MAP: dict[str, str] = cast(dict[str, str], {None: SVG_XML_NS})
# Match fonts used in web UI
SVG_FONTS = [
"'RedHatText'",
@@ -39,7 +39,7 @@ SVG_FONTS = [
]
def avatar_mode_none(user: "User", mode: str) -> str | None:
def avatar_mode_none(user: "User", mode: str) -> str:
"""No avatar"""
return DEFAULT_AVATAR
@@ -62,7 +62,7 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
full_key = CACHE_KEY_GRAVATAR + mail_hash
if cache.has_key(full_key):
cache.touch(full_key)
return cache.get(full_key)
return cast(str | None, cache.get(full_key))
try:
# Since we specify a default of 404, do a HEAD request
@@ -129,16 +129,16 @@ def generate_avatar_from_name(
bg_hex, text_hex = generate_colors(name)
half_size = size // 2
shape = "circle" if rounded else "rect"
shape_type = "circle" if rounded else "rect"
font_weight = "600" if bold else "400"
root_element: Element = Element(f"{{{SVG_XML_NS}}}svg", nsmap=SVG_NS_MAP)
root_element = Element(f"{{{SVG_XML_NS}}}svg", nsmap=SVG_NS_MAP)
root_element.attrib["width"] = f"{size}px"
root_element.attrib["height"] = f"{size}px"
root_element.attrib["viewBox"] = f"0 0 {size} {size}"
root_element.attrib["version"] = "1.1"
shape = SubElement(root_element, f"{{{SVG_XML_NS}}}{shape}", nsmap=SVG_NS_MAP)
shape = SubElement(root_element, f"{{{SVG_XML_NS}}}{shape_type}", nsmap=SVG_NS_MAP)
shape.attrib["fill"] = f"#{bg_hex}"
shape.attrib["cx"] = f"{half_size}"
shape.attrib["cy"] = f"{half_size}"
@@ -150,7 +150,7 @@ def generate_avatar_from_name(
text.attrib["x"] = "50%"
text.attrib["y"] = "50%"
text.attrib["style"] = (
f"color: #{text_hex}; " "line-height: 1; " f"font-family: {','.join(SVG_FONTS)}; "
f"color: #{text_hex}; line-height: 1; font-family: {','.join(SVG_FONTS)}; "
)
text.attrib["fill"] = f"#{text_hex}"
text.attrib["alignment-baseline"] = "middle"
@@ -197,7 +197,7 @@ def get_avatar(user: "User", request: HttpRequest | None = None) -> str:
}
tenant = None
if request:
tenant = request.tenant
tenant = request.tenant # type: ignore[attr-defined]
else:
tenant = get_current_tenant()
modes: str = tenant.avatars

View File

@@ -1,3 +1,5 @@
from typing import Any
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
@@ -5,11 +7,11 @@ from authentik.lib.config import CONFIG
LOGGER = get_logger()
def start_debug_server(**kwargs) -> bool:
def start_debug_server(**kwargs: Any) -> bool:
"""Attempt to start a debugpy server in the current process.
Returns true if the server was started successfully, otherwise false"""
if not CONFIG.get_bool("debug") and not CONFIG.get_bool("debugger"):
return
return False
try:
import debugpy
except ImportError:

View File

@@ -13,10 +13,9 @@ from django.core.exceptions import FieldError
from django.http import HttpRequest
from django.utils.text import slugify
from django.utils.timezone import now
from guardian.shortcuts import get_anonymous_user
from guardian.utils import get_anonymous_user
from rest_framework.serializers import ValidationError
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
from structlog.stdlib import get_logger
from authentik.core.models import User
@@ -55,7 +54,7 @@ class BaseEvaluator:
# Filename used for exec
_filename: str
def __init__(self, filename: str | None = None):
def __init__(self, filename: str | None = None) -> None:
self._filename = filename if filename else "BaseEvaluator"
# update website/docs/expressions/_objects.md
# update website/docs/expressions/_functions.md
@@ -133,12 +132,12 @@ class BaseEvaluator:
return re.sub(regex, repl, value)
@staticmethod
def expr_is_group_member(user: User, **group_filters) -> bool:
def expr_is_group_member(user: User, **group_filters: Any) -> bool:
"""Check if `user` is member of group with name `group_name`"""
return user.all_groups().filter(**group_filters).exists()
@staticmethod
def expr_user_by(**filters) -> User | None:
def expr_user_by(**filters: Any) -> User | None:
"""Get user by filters"""
try:
users = User.objects.filter(**filters)
@@ -160,7 +159,7 @@ class BaseEvaluator:
return False
return len(list(user_devices)) > 0
def expr_event_create(self, action: str, **kwargs):
def expr_event_create(self, action: str, **kwargs: Any) -> None:
"""Create event with supplied data and try to extract as much relevant data
from the context"""
context = self._context.copy()
@@ -181,7 +180,7 @@ class BaseEvaluator:
return
event.save()
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
def expr_func_call_policy(self, name: str, **kwargs: Any) -> PolicyResult:
"""Call policy by name, with current request"""
policy = Policy.objects.filter(name=name).select_subclasses().first()
if not policy:
@@ -214,10 +213,10 @@ class BaseEvaluator:
provider=provider,
user=user,
expires=now() + timedelta_from_string(validity),
scope=scopes,
auth_time=now(),
session=session,
)
access_token.scope = scopes
access_token.id_token = IDToken.new(provider, access_token, request)
access_token.save()
return access_token.token
@@ -229,7 +228,7 @@ class BaseEvaluator:
body: str | None = None,
stage: "EmailStage | None" = None,
template: str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> bool:
"""Send an email using authentik's email system
@@ -316,7 +315,6 @@ class BaseEvaluator:
If any exception is raised during execution, it is raised.
The result is returned without any type-checking."""
with start_span(op="authentik.lib.evaluator.evaluate") as span:
span: Span
span.description = self._filename
span.set_data("expression", expression_source)
try:
@@ -343,7 +341,7 @@ class BaseEvaluator:
raise exc
return result
def handle_error(self, exc: Exception, expression_source: str): # pragma: no cover
def handle_error(self, exc: Exception, expression_source: str) -> None: # pragma: no cover
"""Exception Handler"""
LOGGER.warning("Expression error", exc=exc)

View File

@@ -4,20 +4,20 @@ import string
from random import SystemRandom
def generate_code_fixed_length(length=9) -> str:
def generate_code_fixed_length(length: int = 9) -> str:
"""Generate a numeric code"""
rand = SystemRandom()
num = rand.randrange(1, 10**length)
return str(num).zfill(length)
def generate_id(length=40) -> str:
def generate_id(length: int = 40) -> str:
"""Generate a random client ID"""
rand = SystemRandom()
return "".join(rand.choice(string.ascii_letters + string.digits) for x in range(length))
def generate_key(length=128) -> str:
def generate_key(length: int = 128) -> str:
"""Generate a suitable client secret"""
rand = SystemRandom()
return "".join(

View File

@@ -3,9 +3,11 @@
import logging
from logging import Logger
from os import getpid
from typing import Any
import structlog
from django.db import connection
from structlog.typing import EventDict
from authentik.lib.config import CONFIG
@@ -19,9 +21,9 @@ LOG_PRE_CHAIN = [
]
def get_log_level():
def get_log_level() -> str:
"""Get log level, clamp trace to debug"""
level = CONFIG.get("log_level").upper()
level: str = CONFIG.get("log_level").upper()
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level
@@ -31,7 +33,7 @@ def get_log_level():
return level
def structlog_configure():
def structlog_configure() -> None:
"""Configure structlog itself"""
structlog.configure_once(
processors=[
@@ -56,11 +58,11 @@ def structlog_configure():
)
def get_logger_config():
def get_logger_config() -> dict[str, Any]:
"""Configure python stdlib's logging"""
debug = CONFIG.get_bool("debug")
global_level = get_log_level()
base_config = {
base_config: dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
@@ -121,13 +123,13 @@ def get_logger_config():
return base_config
def add_process_id(logger: Logger, method_name: str, event_dict):
def add_process_id(logger: Logger, method_name: str, event_dict: EventDict) -> EventDict:
"""Add the current process ID"""
event_dict["pid"] = getpid()
return event_dict
def add_tenant_information(logger: Logger, method_name: str, event_dict):
def add_tenant_information(logger: Logger, method_name: str, event_dict: EventDict) -> EventDict:
"""Add the current tenant"""
tenant = getattr(connection, "tenant", None)
schema_name = getattr(connection, "schema_name", None)

View File

@@ -1,6 +1,6 @@
"""merge utils"""
from deepmerge import Merger
from deepmerge import Merger # type: ignore[attr-defined]
MERGE_LIST_UNIQUE = Merger(
[(list, ["append_unique"]), (dict, ["merge"]), (set, ["union"])], ["override"], ["override"]

View File

@@ -1,16 +1,18 @@
"""Migration helpers"""
from collections.abc import Iterable
from collections.abc import Callable, Collection, Generator
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def fallback_names(app: str, model: str, field: str):
def fallback_names(
app: str, model: str, field: str
) -> Callable[[Apps, BaseDatabaseSchemaEditor], None]:
"""Factory function that checks all instances of `app`.`model` instance's `field`
to prevent any duplicates"""
def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor) -> None:
db_alias = schema_editor.connection.alias
klass = apps.get_model(app, model)
@@ -35,7 +37,7 @@ def fallback_names(app: str, model: str, field: str):
return migrator
def progress_bar(iterable: Iterable):
def progress_bar[R](iterable: Collection[R]) -> Generator[R]:
"""Call in a loop to create terminal progress bar
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console"""
@@ -50,7 +52,7 @@ def progress_bar(iterable: Iterable):
if total < 1:
return
def print_progress_bar(iteration):
def print_progress_bar(iteration: int) -> None:
"""Progress Bar Printing Function"""
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
filled_length = int(length * iteration // total)

View File

@@ -1,7 +1,7 @@
"""authentik sentry integration"""
from asyncio.exceptions import CancelledError
from typing import Any
from typing import TYPE_CHECKING, Any
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
@@ -34,6 +34,9 @@ from authentik.lib.utils.reflection import get_env
LOGGER = get_logger()
_root_path = CONFIG.get("web.path", "/")
if TYPE_CHECKING:
from sentry_sdk._types import Event
class SentryIgnoredException(Exception):
"""Base Class for all errors that are suppressed, and not sent to sentry."""
@@ -79,10 +82,11 @@ class SentryTransport(HttpTransport):
def __init__(self, options: dict[str, Any]) -> None:
super().__init__(options)
assert self.parsed_dsn is not None # nosec
self._auth = self.parsed_dsn.to_auth(authentik_user_agent())
def sentry_init(**sentry_init_kwargs):
def sentry_init(**sentry_init_kwargs: Any) -> None:
"""Configure sentry SDK"""
sentry_env = CONFIG.get("error_reporting.environment", "customer")
kwargs = {
@@ -116,7 +120,7 @@ def sentry_init(**sentry_init_kwargs):
set_tag("authentik.component", "backend")
def traces_sampler(sampling_context: dict) -> float:
def traces_sampler(sampling_context: dict[str, Any]) -> float:
"""Custom sampler to ignore certain routes"""
path = sampling_context.get("asgi_scope", {}).get("path", "")
_type = sampling_context.get("asgi_scope", {}).get("type", "")
@@ -135,7 +139,7 @@ def should_ignore_exception(exc: Exception) -> bool:
return isinstance(exc, ignored_classes)
def before_send(event: dict, hint: dict) -> dict | None:
def before_send(event: "Event", hint: dict[str, Any]) -> "Event | None":
"""Check if error is database error, and ignore if so"""
exc_value = None
if "exc_info" in hint:
@@ -157,7 +161,7 @@ def before_send(event: dict, hint: dict) -> dict | None:
return event
def get_http_meta():
def get_http_meta() -> dict[str, Any]:
"""Get sentry-related meta key-values"""
scope = get_current_scope()
meta = {

View File

@@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import Any
from django.db.models import QuerySet
from django.http import HttpRequest
@@ -20,7 +21,7 @@ class PropertyMappingManager:
_evaluators: list[PropertyMappingEvaluator]
globals: dict
globals: dict[str, Any]
__has_compiled: bool
@@ -40,7 +41,7 @@ class PropertyMappingManager:
self.globals = {}
self.__has_compiled = False
def compile(self):
def compile(self) -> None:
self._evaluators = []
for mapping in self.query_set:
if not isinstance(mapping, self.mapping_subclass):
@@ -58,8 +59,8 @@ class PropertyMappingManager:
user: User | None,
request: HttpRequest | None,
return_mapping: bool = False,
**kwargs,
) -> Generator[tuple[dict, PropertyMapping]]:
**kwargs: Any,
) -> Generator[tuple[Any, PropertyMapping]]:
"""Iterate over all mappings that were pre-compiled and
execute all of them with the given context"""
if not self.__has_compiled:

View File

@@ -1,15 +1,26 @@
from typing import Any
from django.db.models import Model
from dramatiq.actor import Actor
from dramatiq.results.errors import ResultFailure
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, ChoiceField
from rest_framework.mixins import (
CreateModelMixin,
DestroyModelMixin,
ListModelMixin,
RetrieveModelMixin,
)
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Group, User
from authentik.events.logs import LogEventSerializer
from authentik.lib.models import SerializerModel
from authentik.lib.sync.api import SyncStatusSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class
@@ -36,11 +47,19 @@ class SyncObjectResultSerializer(PassiveSerializer):
messages = LogEventSerializer(many=True, read_only=True)
class OutgoingSyncProviderStatusMixin:
class OutgoingSyncProviderViewSet(UsedByMixin, ModelViewSet[OutgoingSyncProvider]):
"""Common API Endpoints for Outgoing sync providers"""
sync_task: Actor
sync_objects_task: Actor
sync_task: Actor[[int, Actor[[str, int, int, bool], None]], None]
sync_objects_task: Actor[[str, int, int, bool, dict[str, Any | None]], None]
filterset_fields = [
"name",
"exclude_users_service_account",
"filter_group",
]
search_fields = ["name"]
ordering = ["name"]
@extend_schema(responses={200: SyncStatusSerializer()})
@action(
@@ -68,20 +87,20 @@ class OutgoingSyncProviderStatusMixin:
if not sync_schedule:
return Response(SyncStatusSerializer(status).data)
last_task: Task = (
last_task = (
sync_schedule.tasks.filter(state__in=(TaskStatus.DONE, TaskStatus.REJECTED))
.order_by("-mtime")
.first()
)
last_successful_task: Task = (
last_successful_task = (
sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO))
.order_by("-mtime")
.first()
)
if last_task:
if last_task is not None:
status["last_sync_status"] = last_task.aggregated_status
if last_successful_task:
if last_successful_task is not None:
status["last_successful_sync"] = last_successful_task.mtime
return Response(SyncStatusSerializer(status).data)
@@ -111,7 +130,7 @@ class OutgoingSyncProviderStatusMixin:
"page": 1,
"provider_pk": provider.pk,
"override_dry_run": params.validated_data["override_dry_run"],
"pk": pk,
"filter": {"pk": pk},
},
retries=0,
rel_obj=provider,
@@ -126,13 +145,20 @@ class OutgoingSyncProviderStatusMixin:
return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data)
class OutgoingSyncConnectionCreateMixin:
"""Mixin for connection objects that fetches remote data upon creation"""
def perform_create(self, serializer: ModelSerializer):
class OutgoingSyncConnectionViewSet(
CreateModelMixin,
RetrieveModelMixin,
DestroyModelMixin,
ListModelMixin,
UsedByMixin,
GenericViewSet[SerializerModel],
):
def perform_create(self, serializer: ModelSerializer) -> None: # type: ignore[override]
super().perform_create(serializer)
try:
instance = serializer.instance
if instance is None:
return
client = instance.provider.client_for_model(instance.__class__)
client.update_single_attribute(instance)
instance.save()

View File

@@ -1,7 +1,8 @@
"""Basic outgoing sync Client"""
from collections.abc import MutableMapping
from enum import StrEnum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast
from deepmerge import always_merger
from django.db import DatabaseError
@@ -18,11 +19,11 @@ from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSy
if TYPE_CHECKING:
from django.db.models import Model
from authentik.core.models import Group, User
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
class Direction(StrEnum):
add = "add"
remove = "remove"
@@ -36,7 +37,10 @@ SAFE_METHODS = [
class BaseOutgoingSyncClient[
TModel: "Model", TConnection: "Model", TSchema: dict, TProvider: "OutgoingSyncProvider"
TModel: "User | Group",
TConnection: "Model",
TSchema: MutableMapping[Any, Any],
TProvider: "OutgoingSyncProvider",
]:
"""Basic Outgoing sync client Client"""
@@ -55,14 +59,17 @@ class BaseOutgoingSyncClient[
"""Create object in remote destination"""
raise NotImplementedError()
def update(self, obj: TModel, connection: TConnection):
def update(self, obj: TModel, connection: TConnection) -> None:
"""Update object in remote destination"""
raise NotImplementedError()
def write(self, obj: TModel) -> tuple[TConnection, bool]:
def update_group(self, group: "Group", action: Direction, users_set: list[Any]) -> None:
raise NotImplementedError()
def write(self, obj: TModel) -> tuple[TConnection | None, bool]:
"""Write object to destination. Uses self.create and self.update, but
can be overwritten for further logic"""
connection = self.connection_type.objects.filter(
connection = self.connection_type.objects.filter( # type: ignore[attr-defined]
provider=self.provider, **{self.connection_type_query: obj}
).first()
try:
@@ -82,13 +89,13 @@ class BaseOutgoingSyncClient[
connection.delete()
return None, False
def delete(self, obj: TModel):
def delete(self, obj: TModel) -> None:
"""Delete object from destination"""
raise NotImplementedError()
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults) -> TSchema:
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults: Any) -> TSchema:
"""Convert object to destination schema"""
raw_final_object = {}
raw_final_object: dict[Any, Any] = {}
try:
eval_kwargs = {
"request": None,
@@ -97,7 +104,7 @@ class BaseOutgoingSyncClient[
obj._meta.model_name: obj,
}
eval_kwargs.setdefault("user", None)
for value in self.mapper.iter_eval(**eval_kwargs):
for value in self.mapper.iter_eval(**eval_kwargs): # type: ignore[arg-type, misc]
always_merger.merge(raw_final_object, value)
except ControlFlowException as exc:
raise exc from exc
@@ -113,16 +120,16 @@ class BaseOutgoingSyncClient[
raise StopSync(ValueError("No mappings configured"), obj)
for key, value in defaults.items():
raw_final_object.setdefault(key, value)
return raw_final_object
return cast(TSchema, raw_final_object)
def discover(self):
def discover(self) -> None:
"""Optional method. Can be used to implement a "discovery" where
upon creation of this provider, this function will be called and can
pre-link any users/groups in the remote system with the respective
object in authentik based on a common identifier"""
raise NotImplementedError()
def update_single_attribute(self, connection: TConnection):
def update_single_attribute(self, connection: TConnection) -> None:
"""Update connection attributes on a connection object, when the connection
is manually created"""
raise NotImplementedError

View File

@@ -1,3 +1,5 @@
from typing import Any
from authentik.lib.sentry import SentryIgnoredException
@@ -24,16 +26,16 @@ class BadRequestSyncException(BaseSyncException):
class DryRunRejected(BaseSyncException):
"""When dry_run is enabled and a provider dropped a mutating request"""
def __init__(self, url: str, method: str, body: dict):
def __init__(self, url: str, method: str, body: dict[Any, Any]) -> None:
super().__init__()
self.url = url
self.method = method
self.body = body
def __repr__(self):
def __repr__(self) -> str:
return self.__str__()
def __str__(self):
def __str__(self) -> str:
return f"Dry-run rejected request: {self.method} {self.url}"

View File

@@ -4,11 +4,11 @@ import pglock
from django.core.paginator import Paginator
from django.core.validators import MinValueValidator
from django.db import connection, models
from django.db.models import Model, QuerySet, TextChoices
from django.db.models import QuerySet, TextChoices
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from authentik.core.models import Group, User
from authentik.core.models import BackchannelProvider, Group, User
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.utils.time import fqdn_rand, timedelta_from_string, timedelta_string_validator
from authentik.tasks.schedules.common import ScheduleSpec
@@ -24,7 +24,7 @@ class OutgoingSyncDeleteAction(TextChoices):
SUSPEND = "suspend"
class OutgoingSyncProvider(ScheduledModel, Model):
class OutgoingSyncProvider(ScheduledModel, BackchannelProvider):
"""Base abstract models for providers implementing outgoing sync"""
sync_page_size = models.PositiveIntegerField(
@@ -56,7 +56,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
raise NotImplementedError
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
def get_paginator[T: User | Group](self, type: type[T]) -> "Paginator[T]":
return Paginator(self.get_object_qs(type), self.sync_page_size)
def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
@@ -74,13 +74,15 @@ class OutgoingSyncProvider(ScheduledModel, Model):
def sync_lock(self) -> pglock.advisory:
"""Postgres lock for syncing to prevent multiple parallel syncs happening"""
return pglock.advisory(
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}",
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}", # type: ignore[attr-defined]
timeout=0,
side_effect=pglock.Return,
)
@property
def sync_actor(self) -> Actor:
def sync_actor(
self,
) -> Actor[[int, Actor[[str, int, int, bool, dict[str, Any] | None], None]], None]:
raise NotImplementedError
@property
@@ -94,6 +96,6 @@ class OutgoingSyncProvider(ScheduledModel, Model):
"time_limit": self.get_sync_time_limit_ms(),
},
send_on_save=True,
crontab=f"{fqdn_rand(self.pk)} */4 * * *",
crontab=f"{fqdn_rand(str(self.pk))} */4 * * *",
),
]

View File

@@ -1,4 +1,7 @@
from django.db.models import Model
from collections.abc import Iterable
from typing import Any, TypeVar
from uuid import UUID
from django.db.models.signals import m2m_changed, post_save, pre_delete
from dramatiq.actor import Actor
@@ -7,22 +10,24 @@ from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path
ModelT = TypeVar("ModelT", bound=User | Group)
def register_signals(
provider_type: type[OutgoingSyncProvider],
task_sync_direct_dispatch: Actor[[str, str | int, str], None],
task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
):
task_sync_direct_dispatch: Actor[[str, Any, str], None],
task_sync_m2m_dispatch: Actor[[Any, str, list[Any], bool], None],
) -> None:
"""Register sync signals"""
uid = class_to_path(provider_type)
def model_post_save(
sender: type[Model],
instance: User | Group,
sender: type[ModelT],
instance: ModelT,
created: bool,
update_fields: list[str] | None = None,
**_,
):
update_fields: Iterable[str] | None = None,
**_: Any,
) -> None:
"""Post save handler"""
# Special case for user object; don't start sync task when we've only updated `last_login`
# This primarily happens during user login
@@ -37,7 +42,7 @@ def register_signals(
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
def model_pre_delete(sender: type[ModelT], instance: ModelT, **_: Any) -> None:
"""Pre-delete handler"""
task_sync_direct_dispatch.send(
class_to_path(instance.__class__),
@@ -49,8 +54,13 @@ def register_signals(
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
def model_m2m_changed(
sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
):
sender: type[ModelT],
instance: ModelT,
action: str,
pk_set: set[int | UUID],
reverse: bool,
**_: Any,
) -> None:
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return

View File

@@ -1,9 +1,11 @@
from typing import Any, cast
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.db.models import Model, Q
from dramatiq.actor import Actor
from dramatiq.composition import group
from dramatiq.errors import Retry
from dramatiq.message import Message
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException
@@ -38,11 +40,11 @@ class SyncTasks:
self,
current_task: Task,
provider: OutgoingSyncProvider,
sync_objects: Actor[[str, int, int, bool], None],
paginator: Paginator,
sync_objects: Actor[[str, int, int, bool, dict[str, Any] | None], None],
paginator: "Paginator[User | Group]",
object_type: type[User | Group],
**options,
):
**options: Any,
) -> list[Message[None]]:
tasks = []
time_limit = timedelta_from_string(provider.sync_page_timeout).total_seconds() * 1000
for page in paginator.page_range:
@@ -60,14 +62,14 @@ class SyncTasks:
def sync(
self,
provider_pk: int,
sync_objects: Actor[[str, int, int, bool], None],
):
sync_objects: Actor[[str, int, int, bool, dict[str, Any] | None], None],
) -> None:
task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
)
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
provider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
@@ -82,7 +84,7 @@ class SyncTasks:
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
return
try:
users_tasks = group(
users_tasks = group( # type: ignore[no-untyped-call]
self.sync_paginator(
current_task=task,
provider=provider,
@@ -91,7 +93,7 @@ class SyncTasks:
object_type=User,
)
)
group_tasks = group(
group_tasks = group( # type: ignore[no-untyped-call]
self.sync_paginator(
current_task=task,
provider=provider,
@@ -100,12 +102,12 @@ class SyncTasks:
object_type=Group,
)
)
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) # type: ignore[no-untyped-call]
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) # type: ignore[no-untyped-call]
except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc)
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
raise Retry() from exc
raise Retry() from exc # type: ignore[no-untyped-call]
except StopSync as exc:
task.error(exc)
return
@@ -115,11 +117,11 @@ class SyncTasks:
object_type: str,
page: int,
provider_pk: int,
override_dry_run=False,
**filter,
):
override_dry_run: bool = False,
filter: dict[str, Any] | None = None,
) -> None:
task = CurrentTask.get_task()
_object_type: type[Model] = path_to_class(object_type)
_object_type: type[User | Group] = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
@@ -140,6 +142,8 @@ class SyncTasks:
client = provider.client_for_model(_object_type)
except TransientSyncException:
return
if filter is None:
filter = {}
paginator = Paginator(
provider.get_object_qs(_object_type).filter(**filter),
provider.sync_page_size,
@@ -150,7 +154,6 @@ class SyncTasks:
self.logger.debug("starting sync for page", page=page)
task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
for obj in paginator.page(page).object_list:
obj: Model
try:
client.write(obj)
except SkipObjectException:
@@ -189,11 +192,11 @@ class SyncTasks:
def sync_signal_direct_dispatch(
self,
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
task_sync_signal_direct: Actor[[str, Any, int, str], None],
model: str,
pk: str | int,
pk: Any,
raw_op: str,
):
) -> None:
model_class: type[Model] = path_to_class(model)
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
@@ -207,19 +210,19 @@ class SyncTasks:
def sync_signal_direct(
self,
model: str,
pk: str | int,
pk: Any,
provider_pk: int,
raw_op: str,
):
) -> None:
task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
model_class: type[Model] = path_to_class(model)
model_class: type[User | Group] = path_to_class(model)
instance = model_class.objects.filter(pk=pk).first()
if not instance:
return
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
provider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
@@ -244,7 +247,7 @@ class SyncTasks:
if operation == Direction.remove:
client.delete(instance)
except TransientSyncException as exc:
raise Retry() from exc
raise Retry() from exc # type: ignore[no-untyped-call]
except SkipObjectException:
return
except DryRunRejected as exc:
@@ -254,12 +257,12 @@ class SyncTasks:
def sync_signal_m2m_dispatch(
self,
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
instance_pk: str,
task_sync_signal_m2m: Actor[[Any, int, str, list[Any]], None],
instance_pk: Any,
action: str,
pk_set: list[int],
pk_set: list[Any],
reverse: bool,
):
) -> None:
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
@@ -281,11 +284,11 @@ class SyncTasks:
def sync_signal_m2m(
self,
group_pk: str,
group_pk: Any,
provider_pk: int,
action: str,
pk_set: list[int],
):
pk_set: list[Any],
) -> None:
task = CurrentTask.get_task()
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
@@ -293,7 +296,7 @@ class SyncTasks:
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
provider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
@@ -302,7 +305,7 @@ class SyncTasks:
return
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
queryset = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists():
@@ -315,9 +318,9 @@ class SyncTasks:
operation = Direction.add
if action == "post_remove":
operation = Direction.remove
client.update_group(group, operation, pk_set)
client.update_group(group, cast(Direction, operation), pk_set)
except TransientSyncException as exc:
raise Retry() from exc
raise Retry() from exc # type: ignore[no-untyped-call]
except SkipObjectException:
return
except DryRunRejected as exc:

View File

@@ -1,6 +1,6 @@
"""Test Evaluator base functions"""
from unittest.mock import patch
from unittest.mock import NonCallableMock, patch
from django.test import RequestFactory, TestCase
from django.urls import reverse
@@ -17,27 +17,27 @@ from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
class TestEvaluator(TestCase):
"""Test Evaluator base functions"""
def test_expr_regex_match(self):
def test_expr_regex_match(self) -> None:
"""Test expr_regex_match"""
self.assertFalse(BaseEvaluator.expr_regex_match("foo", "bar"))
self.assertTrue(BaseEvaluator.expr_regex_match("foo", "foo"))
def test_expr_regex_replace(self):
def test_expr_regex_replace(self) -> None:
"""Test expr_regex_replace"""
self.assertEqual(BaseEvaluator.expr_regex_replace("foo", "o", "a"), "faa")
def test_expr_user_by(self):
def test_expr_user_by(self) -> None:
"""Test expr_user_by"""
user = create_test_admin_user()
self.assertIsNotNone(BaseEvaluator.expr_user_by(username=user.username))
self.assertIsNone(BaseEvaluator.expr_user_by(username="bar"))
self.assertIsNone(BaseEvaluator.expr_user_by(foo="bar"))
def test_expr_is_group_member(self):
def test_expr_is_group_member(self) -> None:
"""Test expr_is_group_member"""
self.assertFalse(BaseEvaluator.expr_is_group_member(create_test_admin_user(), name="test"))
def test_expr_event_create(self):
def test_expr_event_create(self) -> None:
"""Test expr_event_create"""
evaluator = BaseEvaluator(generate_id())
evaluator._context = {
@@ -46,10 +46,11 @@ class TestEvaluator(TestCase):
evaluator.evaluate("ak_create_event('foo', bar='baz')")
event = Event.objects.filter(action="custom_foo").first()
self.assertIsNotNone(event)
assert event is not None # nosec
self.assertEqual(event.context, {"bar": "baz", "foo": "bar"})
@apply_blueprint("system/providers-oauth2.yaml")
def test_expr_create_jwt(self):
def test_expr_create_jwt(self) -> None:
"""Test expr_create_jwt"""
rf = RequestFactory()
user = create_test_user()
@@ -81,7 +82,7 @@ class TestEvaluator(TestCase):
self.assertEqual(decoded["preferred_username"], user.username)
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_body(self, mock_send_mails):
def test_expr_send_email_with_body(self, mock_send_mails: NonCallableMock) -> None:
"""Test ak_send_email with body parameter"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
@@ -108,7 +109,7 @@ class TestEvaluator(TestCase):
self.assertEqual(message.body, "Test Body")
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_template(self, mock_send_mails):
def test_expr_send_email_with_template(self, mock_send_mails: NonCallableMock) -> None:
"""Test ak_send_email with template parameter"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
@@ -123,7 +124,7 @@ class TestEvaluator(TestCase):
self.assertTrue(result)
mock_send_mails.assert_called_once()
def test_expr_send_email_validation_errors(self):
def test_expr_send_email_validation_errors(self) -> None:
"""Test ak_send_email validation errors"""
evaluator = BaseEvaluator(generate_id())
@@ -141,7 +142,7 @@ class TestEvaluator(TestCase):
self.assertIn("Either body or template parameter must be provided", str(cm.exception))
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_custom_stage(self, mock_send_mails):
def test_expr_send_email_with_custom_stage(self, mock_send_mails: NonCallableMock) -> None:
"""Test ak_send_email with custom EmailStage"""
from authentik.stages.email.models import EmailStage
@@ -170,7 +171,7 @@ class TestEvaluator(TestCase):
self.assertFalse(stage.use_global_settings)
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_with_context(self, mock_send_mails):
def test_expr_send_email_with_context(self, mock_send_mails: NonCallableMock) -> None:
"""Test ak_send_email with custom context parameter"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
@@ -199,7 +200,7 @@ class TestEvaluator(TestCase):
self.assertIn("http://localhost", message.body)
@patch("authentik.stages.email.tasks.send_mails")
def test_expr_send_email_multiple_addresses(self, mock_send_mails):
def test_expr_send_email_multiple_addresses(self, mock_send_mails: NonCallableMock) -> None:
"""Test ak_send_email with multiple email addresses"""
user = create_test_user()
evaluator = BaseEvaluator(generate_id())
@@ -226,7 +227,7 @@ class TestEvaluator(TestCase):
self.assertEqual(message.to, ["user1@example.com", "user2@example.com"])
self.assertEqual(message.body, "Test Body")
def test_expr_send_email_multiple_addresses_validation(self):
def test_expr_send_email_multiple_addresses_validation(self) -> None:
"""Test ak_send_email validation with multiple addresses"""
evaluator = BaseEvaluator(generate_id())

View File

@@ -15,27 +15,27 @@ class TestHTTP(TestCase):
self.user = create_test_admin_user()
self.factory = RequestFactory()
def test_bad_request_message(self):
def test_bad_request_message(self) -> None:
"""test bad_request_message"""
request = self.factory.get("/")
self.assertEqual(bad_request_message(request, "foo").status_code, 400)
def test_normal(self):
def test_normal(self) -> None:
"""Test normal request"""
request = self.factory.get("/")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
def test_forward_for(self):
def test_forward_for(self) -> None:
"""Test x-forwarded-for request"""
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
def test_forward_for_invalid(self):
def test_forward_for_invalid(self) -> None:
"""Test invalid forward for"""
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
def test_fake_outpost(self):
def test_fake_outpost(self) -> None:
"""Test faked IP which is overridden by an outpost"""
token = Token.objects.create(
identifier="test", user=self.user, intent=TokenIntents.INTENT_API
@@ -43,7 +43,7 @@ class TestHTTP(TestCase):
# Invalid, non-existent token
request = self.factory.get(
"/",
**{
**{ # type: ignore[arg-type]
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
ClientIPMiddleware.outpost_token_header: "abc",
},
@@ -52,7 +52,7 @@ class TestHTTP(TestCase):
# Invalid, user doesn't have permissions
request = self.factory.get(
"/",
**{
**{ # type: ignore[arg-type]
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
ClientIPMiddleware.outpost_token_header: token.key,
},
@@ -63,7 +63,7 @@ class TestHTTP(TestCase):
self.user.save()
request = self.factory.get(
"/",
**{
**{ # type: ignore[arg-type]
ClientIPMiddleware.outpost_remote_ip_header: "foobar",
ClientIPMiddleware.outpost_token_header: token.key,
},
@@ -74,7 +74,7 @@ class TestHTTP(TestCase):
self.user.save()
request = self.factory.get(
"/",
**{
**{ # type: ignore[arg-type]
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
ClientIPMiddleware.outpost_token_header: token.key,
},

View File

@@ -8,10 +8,10 @@ from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
class TestSentry(TestCase):
"""test sentry integration"""
def test_error_not_sent(self):
def test_error_not_sent(self) -> None:
"""Test SentryIgnoredError not sent"""
self.assertTrue(should_ignore_exception(SentryIgnoredException()))
def test_error_sent(self):
def test_error_sent(self) -> None:
"""Test error sent"""
self.assertFalse(should_ignore_exception(ValueError()))

View File

@@ -5,7 +5,6 @@ from collections.abc import Callable
from django.test import TestCase
from rest_framework.serializers import BaseSerializer
from authentik.flows.models import Stage
from authentik.lib.models import SerializerModel
from authentik.lib.utils.reflection import all_subclasses
@@ -14,10 +13,10 @@ class TestModels(TestCase):
"""Generic model properties tests"""
def model_tester_factory(test_model: type[Stage]) -> Callable:
def model_tester_factory(test_model: type[SerializerModel]) -> Callable[[TestModels], None]:
"""Test a form"""
def tester(self: TestModels):
def tester(self: TestModels) -> None:
try:
model_class = None
if test_model._meta.abstract: # pragma: no cover
@@ -31,4 +30,4 @@ def model_tester_factory(test_model: type[Stage]) -> Callable:
for model in all_subclasses(SerializerModel):
setattr(TestModels, f"test_model_{model.__name__}", model_tester_factory(model))
setattr(TestModels, f"test_model_{model.__name__}", model_tester_factory(model)) # type: ignore[type-abstract]

View File

@@ -10,6 +10,6 @@ from authentik.lib.utils.reflection import path_to_class
class TestReflectionUtils(TestCase):
"""Test Reflection-utils"""
def test_path_to_class(self):
def test_path_to_class(self) -> None:
"""Test path_to_class"""
self.assertEqual(path_to_class("datetime.datetime"), datetime)

View File

@@ -11,20 +11,20 @@ from authentik.lib.utils.time import timedelta_from_string, timedelta_string_val
class TestTimeUtils(TestCase):
"""Test time-utils"""
def test_valid(self):
def test_valid(self) -> None:
"""Test valid expression"""
expr = "hours=3;minutes=1"
expected = timedelta(hours=3, minutes=1)
self.assertEqual(timedelta_from_string(expr), expected)
def test_invalid(self):
def test_invalid(self) -> None:
"""Test invalid expression"""
with self.assertRaises(ValueError):
timedelta_from_string("foo")
with self.assertRaises(ValueError):
timedelta_from_string("bar=baz")
def test_validation(self):
def test_validation(self) -> None:
"""Test Django model field validator"""
with self.assertRaises(ValidationError):
timedelta_string_validator("foo")

View File

@@ -2,23 +2,31 @@
from inspect import currentframe
from pathlib import Path
from typing import Any
from django.contrib.messages.middleware import MessageMiddleware
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import HttpRequest
from django.core.handlers.wsgi import WSGIRequest
from django.http import HttpRequest, HttpResponse
from django.test.client import RequestFactory
from guardian.utils import get_anonymous_user
from authentik.core.models import User
def dummy_get_response(request: HttpRequest): # pragma: no cover
def dummy_get_response(request: HttpRequest) -> HttpResponse: # pragma: no cover
"""Dummy get_response for SessionMiddleware"""
return None
return HttpResponse()
def load_fixture(path: str, **kwargs) -> str:
def load_fixture(path: str, **kwargs: Any) -> str:
"""Load fixture, optionally formatting it with kwargs"""
current = currentframe()
if current is None:
return ""
parent = current.f_back
if parent is None:
return ""
calling_file_path = parent.f_globals["__file__"]
with open(Path(calling_file_path).resolve().parent / Path(path), encoding="utf-8") as _fixture:
fixture = _fixture.read()
@@ -28,17 +36,17 @@ def load_fixture(path: str, **kwargs) -> str:
return fixture
def get_request(*args, user=None, **kwargs):
def get_request(*args: Any, user: User | None = None, **kwargs: Any) -> WSGIRequest:
"""Get a request with usable session"""
request = RequestFactory().get(*args, **kwargs)
if user:
if user is not None:
request.user = user
else:
request.user = get_anonymous_user()
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(request)
session_middleware = SessionMiddleware(dummy_get_response)
session_middleware.process_request(request)
request.session.save()
middleware = MessageMiddleware(dummy_get_response)
middleware.process_request(request)
message_middleware = MessageMiddleware(dummy_get_response)
message_middleware.process_request(request)
request.session.save()
return request

View File

@@ -1,16 +1,22 @@
"""authentik database utilities"""
import gc
from collections.abc import Generator
from typing import TypeVar
from django.db import reset_queries
from django.db.models import QuerySet
from django.db.models import Model, QuerySet
ModelT_co = TypeVar("ModelT_co", bound=Model, covariant=True)
def chunked_queryset(queryset: QuerySet, chunk_size: int = 1_000):
def chunked_queryset(
queryset: QuerySet[ModelT_co], chunk_size: int = 1_000
) -> Generator[ModelT_co]:
if not queryset.exists():
return []
return
def get_chunks(qs: QuerySet):
def get_chunks(qs: QuerySet[ModelT_co]) -> Generator[QuerySet[ModelT_co]]:
qs = qs.order_by("pk")
pks = qs.values_list("pk", flat=True)
start_pk = pks[0]

View File

@@ -1,7 +1,14 @@
from typing import Any
from typing import Any, cast
type rdict[R] = dict[str, "rdict[R] | R"]
def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
def get_path_from_dict[R: Any](
root: rdict[R],
path: str,
sep: str = ".",
default: R | None = None,
) -> Any | None:
"""Recursively walk through `root`, checking each part of `path` separated by `sep`.
If at any point a dict does not exist, return default"""
walk: Any = root
@@ -10,10 +17,10 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
walk = walk.get(comp)
else:
return default
return walk
return cast(R, walk)
def set_path_in_dict(root: dict, path: str, value: Any, sep="."):
def set_path_in_dict[R: Any](root: rdict[R], path: str, value: R, sep: str = ".") -> None:
"""Recursively walk through `root`, checking each part of `path` separated by `sep`
and setting the last value to `value`"""
# Walk each component of the path

View File

@@ -1,7 +1,7 @@
"""file utils"""
from django.db.models import Model
from django.http import HttpResponseBadRequest
from django.http import HttpResponse, HttpResponseBadRequest
from rest_framework.fields import BooleanField, CharField, FileField
from rest_framework.request import Request
from rest_framework.response import Response
@@ -25,7 +25,7 @@ class FilePathSerializer(PassiveSerializer):
url = CharField()
def set_file(request: Request, obj: Model, field_name: str):
def set_file(request: Request, obj: Model, field_name: str) -> HttpResponse:
"""Upload file"""
field = getattr(obj, field_name)
file = request.FILES.get("file", None)
@@ -45,7 +45,7 @@ def set_file(request: Request, obj: Model, field_name: str):
return HttpResponseBadRequest()
def set_file_url(request: Request, obj: Model, field_name: str):
def set_file_url(request: Request, obj: Model, field_name: str) -> HttpResponse:
"""Set file field to URL"""
field = getattr(obj, field_name)
url = request.data.get("url", None)

View File

@@ -1,13 +1,18 @@
"""http helpers"""
from typing import TYPE_CHECKING, Any
from uuid import uuid4
from requests.models import Response
from requests.sessions import PreparedRequest, Session
from structlog.stdlib import get_logger
from authentik import authentik_full_version
from authentik.lib.config import CONFIG
if TYPE_CHECKING:
from requests.sessions import _Timeout
LOGGER = get_logger()
@@ -19,50 +24,40 @@ def authentik_user_agent() -> str:
class TimeoutSession(Session):
"""Always set a default HTTP request timeout"""
def __init__(self, default_timeout=None):
def __init__(self, default_timeout: int | None = None) -> None:
super().__init__()
self.timeout = default_timeout
def send(
self,
request,
request: PreparedRequest,
*,
stream=...,
verify=...,
proxies=...,
cert=...,
timeout=...,
allow_redirects=...,
**kwargs,
):
timeout: "_Timeout | None" = None,
**kwargs: Any,
) -> Response:
if not timeout and self.timeout:
timeout = self.timeout
return super().send(
request,
stream=stream,
verify=verify,
proxies=proxies,
cert=cert,
timeout=timeout,
allow_redirects=allow_redirects,
**kwargs,
)
return super().send(request, timeout=timeout, **kwargs)
class DebugSession(TimeoutSession):
"""requests session which logs http requests and responses"""
def send(self, req: PreparedRequest, *args, **kwargs):
def send(
self,
request: PreparedRequest,
**kwargs: Any,
) -> Response:
request_id = str(uuid4())
LOGGER.debug(
"HTTP request sent",
uid=request_id,
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
url=request.url,
method=request.method,
headers=request.headers,
body=request.body,
)
resp = super().send(req, *args, **kwargs)
resp = super().send(request, **kwargs)
LOGGER.debug(
"HTTP response received",
uid=request_id,

View File

@@ -1,10 +1,13 @@
"""authentik lib reflection utilities"""
import os
from collections.abc import Generator
from importlib import import_module
from pathlib import Path
from tempfile import gettempdir
from typing import cast
from django.apps.config import AppConfig
from django.conf import settings
from django.utils.module_loading import import_string
@@ -13,9 +16,9 @@ from authentik.lib.config import CONFIG
SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST"
def all_subclasses[T: type](cls: T, sort=True) -> list[T] | set[T]:
def all_subclasses[T: type](cls: T, sort: bool = True) -> list[T] | set[T]:
"""Recursively return all subclassess of cls"""
classes = set(cls.__subclasses__()).union(
classes: list[T] | set[T] = set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)]
)
# Check if we're in debug mode, if not exclude classes which have `__debug_only__`
@@ -38,10 +41,10 @@ def path_to_class(path: str = "") -> type:
parts = path.split(".")
package = ".".join(parts[:-1])
_class = getattr(import_module(package), parts[-1])
return _class
return cast(type, _class)
def get_apps():
def get_apps() -> Generator[AppConfig]:
"""Get list of all authentik apps"""
from django.apps.registry import apps
@@ -65,11 +68,11 @@ def get_env() -> str:
return "custom"
def ConditionalInheritance(path: str):
def ConditionalInheritance(path: str) -> type:
"""Conditionally inherit from a class, intended for things like authentik.enterprise,
without which authentik should still be able to run"""
try:
cls = import_string(path)
return cls
return cast(type, cls)
except ModuleNotFoundError:
return object

View File

@@ -19,7 +19,7 @@ ALLOWED_KEYS = (
)
def timedelta_string_validator(value: str):
def timedelta_string_validator(value: str) -> None:
"""Validator for Django that checks if value can be parsed with `timedelta_from_string`"""
try:
timedelta_from_string(value)

View File

@@ -1,5 +1,6 @@
"""URL-related utils"""
from typing import Any
from urllib.parse import urlparse
from django.http import HttpResponse, QueryDict
@@ -10,12 +11,12 @@ from structlog.stdlib import get_logger
LOGGER = get_logger()
def is_url_absolute(url):
def is_url_absolute(url: str | bytes | bytearray | None) -> bool:
"""Check if domain is absolute to prevent user from being redirect somewhere else"""
return bool(urlparse(url).netloc)
def redirect_with_qs(view: str, get_query_set: QueryDict | None = None, **kwargs) -> HttpResponse:
def redirect_with_qs(view: str, qs: QueryDict | None = None, **kwargs: Any) -> HttpResponse:
"""Wrapper to redirect whilst keeping GET Parameters"""
try:
target = reverse(view, kwargs=kwargs)
@@ -24,14 +25,14 @@ def redirect_with_qs(view: str, get_query_set: QueryDict | None = None, **kwargs
return redirect(view)
LOGGER.warning("redirect target is not a valid view", view=view)
raise
if get_query_set:
target += "?" + get_query_set.urlencode()
if qs:
target += "?" + qs.urlencode()
return redirect(target)
def reverse_with_qs(view: str, query: QueryDict | None = None, **kwargs) -> str:
def reverse_with_qs(view: str, qs: QueryDict | None = None, **kwargs: Any) -> str:
"""Reverse a view to it's url but include get params"""
url = reverse(view, **kwargs)
if query:
url += "?" + query.urlencode()
if qs:
url += "?" + qs.urlencode()
return url

View File

@@ -1,10 +1,17 @@
"""Serializer validators"""
from typing import TYPE_CHECKING, Any, TypeVar
from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import Serializer
from rest_framework.utils.representation import smart_repr
if TYPE_CHECKING:
from django.utils.functional import _StrPromise
_IN = TypeVar("_IN") # Instance Type
class RequiredTogetherValidator:
"""Serializer-level validator that ensures all fields in `fields` are only
@@ -12,13 +19,13 @@ class RequiredTogetherValidator:
fields: list[str]
requires_context = True
message = _("The fields {field_names} must be used together.")
message: "str | _StrPromise" = _("The fields {field_names} must be used together.")
def __init__(self, fields: list[str], message: str | None = None) -> None:
def __init__(self, fields: list[str], message: "str | _StrPromise | None" = None) -> None:
self.fields = fields
self.message = message or self.message
def __call__(self, attrs: dict, serializer: Serializer):
def __call__(self, attrs: dict[Any, Any], serializer: Serializer[_IN]) -> None:
"""Check that if any of the fields in `self.fields` are set, all of them must be set"""
if any(field in attrs for field in self.fields) and not all(
field in attrs for field in self.fields
@@ -27,5 +34,5 @@ class RequiredTogetherValidator:
message = self.message.format(field_names=field_names)
raise ValidationError(message, code="required")
def __repr__(self):
def __repr__(self) -> str:
    """Debug representation listing the field names this validator covers."""
    cls_name = self.__class__.__name__
    fields_repr = smart_repr(self.fields)
    return f"<{cls_name}(fields={fields_repr})>"

View File

@@ -8,8 +8,8 @@ from django.utils.translation import gettext_lazy as _
def bad_request_message(
request: HttpRequest,
message: str,
title="Bad Request",
template="if/error.html",
title: str = "Bad Request",
template: str = "if/error.html",
) -> TemplateResponse:
"""Return generic error page with message, with status code set to 400"""
return TemplateResponse(

View File

@@ -1,13 +1,13 @@
"""XML Utilities"""
from lxml.etree import XMLParser, fromstring # nosec
from lxml.etree import XMLParser, _Element, fromstring # nosec
def get_lxml_parser():
def get_lxml_parser() -> XMLParser:
    """Return an ``XMLParser`` with entity resolution turned off."""
    parser = XMLParser(resolve_entities=False)
    return parser
def lxml_from_string(text: str):
def lxml_from_string(text: str) -> _Element:
    """Parse *text* into an lxml element using the parser from ``get_lxml_parser``."""
    parser = get_lxml_parser()
    return fromstring(text, parser=parser)  # nosec

View File

@@ -5,7 +5,6 @@ from multiprocessing.connection import Connection
from django.core.cache import cache
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
from structlog.stdlib import get_logger
from authentik.events.models import Event, EventAction
@@ -121,7 +120,7 @@ class PolicyProcess(PROCESS_CLASS):
)
return policy_result
def profiling_wrapper(self):
def profiling_wrapper(self) -> PolicyResult:
"""Run with profiling enabled"""
with (
start_span(
@@ -135,7 +134,6 @@ class PolicyProcess(PROCESS_CLASS):
mode="execute_process",
).time(),
):
span: Span
span.set_data("policy", self.binding.policy)
span.set_data("request", self.request)
return self.execute()

View File

@@ -458,7 +458,7 @@ class BaseGrantModel(models.Model):
return self._scope.split()
@scope.setter
def scope(self, value):
def scope(self, value: list[str]) -> None:
    """Persist the scope list as a single space-separated string on ``_scope``."""
    joined = " ".join(value)
    self._scope = joined

View File

@@ -1,12 +1,8 @@
"""SCIMProviderGroup API Views"""
from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.users import PartialGroupSerializer
from authentik.core.api.utils import ModelSerializer
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
from authentik.providers.scim.models import SCIMProviderGroup
@@ -16,7 +12,6 @@ class SCIMProviderGroupSerializer(ModelSerializer):
group_obj = PartialGroupSerializer(source="group", read_only=True)
class Meta:
model = SCIMProviderGroup
fields = [
"id",
@@ -29,15 +24,7 @@ class SCIMProviderGroupSerializer(ModelSerializer):
extra_kwargs = {"attributes": {"read_only": True}}
class SCIMProviderGroupViewSet(
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class SCIMProviderGroupViewSet(OutgoingSyncConnectionViewSet):
"""SCIMProviderGroup Viewset"""
queryset = SCIMProviderGroup.objects.all().select_related("group")

View File

@@ -1,10 +1,7 @@
"""SCIM Provider API Views"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderViewSet
from authentik.lib.utils.reflection import ConditionalInheritance
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
@@ -45,13 +42,16 @@ class SCIMProviderSerializer(
extra_kwargs = {}
class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
class SCIMProviderViewSet(OutgoingSyncProviderViewSet):
"""SCIMProvider Viewset"""
queryset = SCIMProvider.objects.all()
serializer_class = SCIMProviderSerializer
filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
search_fields = ["name", "url"]
ordering = ["name", "url"]
filterset_fields = OutgoingSyncProviderViewSet.filterset_fields + [
"url",
]
search_fields = OutgoingSyncProviderViewSet.search_fields + [
"url",
]
sync_task = scim_sync
sync_objects_task = scim_sync_objects

View File

@@ -1,12 +1,8 @@
"""SCIMProviderUser API Views"""
from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.groups import PartialUserSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
from authentik.providers.scim.models import SCIMProviderUser
@@ -16,7 +12,6 @@ class SCIMProviderUserSerializer(ModelSerializer):
user_obj = PartialUserSerializer(source="user", read_only=True)
class Meta:
model = SCIMProviderUser
fields = [
"id",
@@ -29,15 +24,7 @@ class SCIMProviderUserSerializer(ModelSerializer):
extra_kwargs = {"attributes": {"read_only": True}}
class SCIMProviderUserViewSet(
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class SCIMProviderUserViewSet(OutgoingSyncConnectionViewSet):
"""SCIMProviderUser Viewset"""
queryset = SCIMProviderUser.objects.all().select_related("user")

View File

@@ -12,7 +12,7 @@ from requests.auth import AuthBase
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
from authentik.core.models import Group, PropertyMapping, User, UserTypes
from authentik.lib.models import SerializerModel
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
@@ -81,7 +81,7 @@ class SCIMCompatibilityMode(models.TextChoices):
SALESFORCE = "sfdc", _("Salesforce")
class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
class SCIMProvider(OutgoingSyncProvider):
"""SCIM 2.0 provider to create users and groups in external applications"""
exclude_users_service_account = models.BooleanField(default=False)

View File

@@ -253,19 +253,19 @@ class IdentificationStageView(ChallengeStageView):
if current_stage.enrollment_flow:
challenge.initial_data["enroll_url"] = reverse_with_qs(
"authentik_core:if-flow",
query=get_qs,
qs=get_qs,
kwargs={"flow_slug": current_stage.enrollment_flow.slug},
)
if current_stage.recovery_flow:
challenge.initial_data["recovery_url"] = reverse_with_qs(
"authentik_core:if-flow",
query=get_qs,
qs=get_qs,
kwargs={"flow_slug": current_stage.recovery_flow.slug},
)
if current_stage.passwordless_flow:
challenge.initial_data["passwordless_url"] = reverse_with_qs(
"authentik_core:if-flow",
query=get_qs,
qs=get_qs,
kwargs={"flow_slug": current_stage.passwordless_flow.slug},
)

View File

@@ -104,7 +104,9 @@ dev = [
"requests-mock==1.12.1",
"ruff==0.11.9",
"selenium==4.32.0",
"types-cachetools==6.2.0.20251022",
"types-channels==4.3.0.20250822",
"types-docker==7.1.0.20251009",
"types-ldap3==2.9.13.20250622",
]
@@ -208,7 +210,7 @@ plugins = ["mypy_django_plugin.main", "mypy_drf_plugin.main", "pydantic.mypy"]
exclude = ['^gen-py-api/']
[[tool.mypy.overrides]]
module = ["django_tenants.*", "dramatiq.*", "pglock.*"]
module = ["django_tenants.*", "dramatiq.*", "pglock.*", "debugpy.*"]
follow_untyped_imports = true
[[tool.mypy.overrides]]
@@ -226,7 +228,9 @@ module = [
"authentik.enterprise.*",
"authentik.events.*",
"authentik.flows.*",
"authentik.lib.*",
"authentik.lib.config",
"authentik.lib.models",
"authentik.lib.tests.test_config",
"authentik.outposts.*",
"authentik.policies.*",
"authentik.policies.dummy.*",

39
uv.lock generated
View File

@@ -272,7 +272,9 @@ dev = [
{ name = "requests-mock" },
{ name = "ruff" },
{ name = "selenium" },
{ name = "types-cachetools" },
{ name = "types-channels" },
{ name = "types-docker" },
{ name = "types-ldap3" },
]
@@ -377,7 +379,9 @@ dev = [
{ name = "requests-mock", specifier = "==1.12.1" },
{ name = "ruff", specifier = "==0.11.9" },
{ name = "selenium", specifier = "==4.32.0" },
{ name = "types-cachetools", specifier = "==6.2.0.20251022" },
{ name = "types-channels", specifier = "==4.3.0.20250822" },
{ name = "types-docker", specifier = "==7.1.0.20251009" },
{ name = "types-ldap3", specifier = "==2.9.13.20250622" },
]
@@ -3389,6 +3393,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9a/bb/d43e5c75054e53efce310e79d63df0ac3f25e34c926be5dffb7d283fb2a8/typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1", size = 17605, upload-time = "2021-12-10T21:09:37.844Z" },
]
[[package]]
name = "types-cachetools"
version = "6.2.0.20251022"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/3b/a8/f9bcc7f1be63af43ef0170a773e2d88817bcc7c9d8769f2228c802826efe/types_cachetools-6.2.0.20251022.tar.gz", hash = "sha256:f1d3c736f0f741e89ec10f0e1b0138625023e21eb33603a930c149e0318c0cef", size = 9608, upload-time = "2025-10-22T03:03:58.16Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/98/2d/8d821ed80f6c2c5b427f650bf4dc25b80676ed63d03388e4b637d2557107/types_cachetools-6.2.0.20251022-py3-none-any.whl", hash = "sha256:698eb17b8f16b661b90624708b6915f33dbac2d185db499ed57e4997e7962cad", size = 9341, upload-time = "2025-10-22T03:03:57.036Z" },
]
[[package]]
name = "types-channels"
version = "4.3.0.20250822"
@@ -3402,6 +3415,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/52/4e3094e43d460feacb9051ec4c3498f8272f69d92b772647211478b25079/types_channels-4.3.0.20250822-py3-none-any.whl", hash = "sha256:d3fc0a1467c8cc901686826408c8a673822e07aa79cbe1a6d21946e7e55d9ddf", size = 21125, upload-time = "2025-08-22T03:04:25.539Z" },
]
[[package]]
name = "types-docker"
version = "7.1.0.20251009"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "types-paramiko" },
{ name = "types-requests" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/93/9d/c9427adb04df55f3821b042612f8f7555c7060d6a3b589549a10b7a15c3a/types_docker-7.1.0.20251009.tar.gz", hash = "sha256:37af2a9ed5c3d76308ee9b9958cf1506fe9bcfbfed9c0a20bd9856dbca90424e", size = 31647, upload-time = "2025-10-09T02:54:40.976Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/71/bb/da793021a87528e2ca717f117d0d94149c429425a33dea036675932a0170/types_docker-7.1.0.20251009-py3-none-any.whl", hash = "sha256:e0ed83c70b824d0efffca6e61662e2722109207515579782fa27c505ea06fb7d", size = 46417, upload-time = "2025-10-09T02:54:40.035Z" },
]
[[package]]
name = "types-ldap3"
version = "2.9.13.20250622"
@@ -3414,6 +3441,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/fd/0339a618985d19d9b0630f78822d1becb0661be6abe8adbadd9569b875e1/types_ldap3-2.9.13.20250622-py3-none-any.whl", hash = "sha256:c18d0320327fa0017eb3d95acdf38921542d80939255e4ba130ca2d13ca3375f", size = 56498, upload-time = "2025-06-22T03:19:15.495Z" },
]
[[package]]
name = "types-paramiko"
version = "4.0.0.20250822"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b7/b8/c6ff3b10c2f7b9897650af746f0dc6c5cddf054db857bc79d621f53c7d22/types_paramiko-4.0.0.20250822.tar.gz", hash = "sha256:1b56b0cbd3eec3d2fd123c9eb2704e612b777e15a17705a804279ea6525e0c53", size = 28730, upload-time = "2025-08-22T03:03:43.262Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/79/a1/b3774ed924a66ee2c041224d89c36f0c21f4f6cf75036d6ee7698bf8a4b9/types_paramiko-4.0.0.20250822-py3-none-any.whl", hash = "sha256:55bdb14db75ca89039725ec64ae3fa26b8d57b6991cfb476212fa8f83a59753c", size = 38833, upload-time = "2025-08-22T03:03:42.072Z" },
]
[[package]]
name = "types-pyasn1"
version = "0.6.0.20250914"