mirror of
https://github.com/goauthentik/authentik
synced 2026-05-06 07:02:51 +02:00
Compare commits
20 Commits
docs/invit
...
lib-typing
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1588e6d130 | ||
|
|
fc0366b3f4 | ||
|
|
db849599f5 | ||
|
|
267f9d9905 | ||
|
|
2e62d7cb14 | ||
|
|
c4adff1b26 | ||
|
|
adcad1350d | ||
|
|
94f64882ab | ||
|
|
e0b592c035 | ||
|
|
02ae7eada7 | ||
|
|
c12749e3e9 | ||
|
|
171e83b088 | ||
|
|
219666c32d | ||
|
|
e9ec83fd03 | ||
|
|
123cca34a1 | ||
|
|
3a2559b115 | ||
|
|
a9b50c8c77 | ||
|
|
b81d415faf | ||
|
|
58d5d37953 | ||
|
|
cc71bb6e74 |
@@ -2,23 +2,27 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from django.apps import apps
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.blueprints.models import BlueprintInstance
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def apply_blueprint(*files: str):
|
||||
|
||||
def apply_blueprint(*files: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Apply blueprint before test"""
|
||||
|
||||
from authentik.blueprints.v1.importer import Importer
|
||||
|
||||
def wrapper_outer(func: Callable):
|
||||
def wrapper_outer(func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Apply blueprint before test"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
for file in files:
|
||||
content = BlueprintInstance(path=file).retrieve()
|
||||
Importer.from_string(content).apply()
|
||||
|
||||
@@ -3,12 +3,11 @@
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Model
|
||||
from django.http import HttpRequest
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import PropertyMapping, User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.policies.types import PolicyRequest
|
||||
@@ -23,13 +22,13 @@ PROPERTY_MAPPING_TIME = Histogram(
|
||||
class PropertyMappingEvaluator(BaseEvaluator):
|
||||
"""Custom Evaluator that adds some different context variables."""
|
||||
|
||||
dry_run: bool
|
||||
model: Model
|
||||
dry_run: bool | None
|
||||
model: PropertyMapping
|
||||
_compiled: CodeType | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
model: PropertyMapping,
|
||||
user: User | None = None,
|
||||
request: HttpRequest | None = None,
|
||||
dry_run: bool | None = False,
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""GoogleWorkspaceProviderGroup API Views"""
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import PartialGroupSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
|
||||
@@ -16,7 +12,6 @@ class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
|
||||
group_obj = PartialGroupSerializer(source="group", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
model = GoogleWorkspaceProviderGroup
|
||||
fields = [
|
||||
"id",
|
||||
@@ -29,15 +24,7 @@ class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderGroupViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
OutgoingSyncConnectionCreateMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
class GoogleWorkspaceProviderGroupViewSet(OutgoingSyncConnectionViewSet):
|
||||
"""GoogleWorkspaceProviderGroup Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group")
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
"""Google Provider API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import (
|
||||
google_workspace_sync,
|
||||
google_workspace_sync_objects,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderViewSet
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
|
||||
@@ -44,18 +41,16 @@ class GoogleWorkspaceProviderSerializer(EnterpriseRequiredMixin, ProviderSeriali
|
||||
extra_kwargs = {}
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
|
||||
class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderViewSet):
|
||||
"""GoogleWorkspaceProvider Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProvider.objects.all()
|
||||
serializer_class = GoogleWorkspaceProviderSerializer
|
||||
filterset_fields = [
|
||||
"name",
|
||||
"exclude_users_service_account",
|
||||
filterset_fields = OutgoingSyncProviderViewSet.filterset_fields + [
|
||||
"delegated_subject",
|
||||
]
|
||||
search_fields = OutgoingSyncProviderViewSet.search_fields + [
|
||||
"delegated_subject",
|
||||
"filter_group",
|
||||
]
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_task = google_workspace_sync
|
||||
sync_objects_task = google_workspace_sync_objects
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""GoogleWorkspaceProviderUser API Views"""
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import PartialUserSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
|
||||
@@ -16,7 +12,6 @@ class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
|
||||
user_obj = PartialUserSerializer(source="user", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
model = GoogleWorkspaceProviderUser
|
||||
fields = [
|
||||
"id",
|
||||
@@ -29,15 +24,7 @@ class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderUserViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
OutgoingSyncConnectionCreateMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
class GoogleWorkspaceProviderUserViewSet(OutgoingSyncConnectionViewSet):
|
||||
"""GoogleWorkspaceProviderUser Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user")
|
||||
|
||||
@@ -12,7 +12,6 @@ from google.oauth2.service_account import Credentials
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import (
|
||||
BackchannelProvider,
|
||||
Group,
|
||||
PropertyMapping,
|
||||
User,
|
||||
@@ -84,7 +83,7 @@ class GoogleWorkspaceProviderGroup(SerializerModel):
|
||||
return f"Google Workspace Provider Group {self.group_id} to {self.provider_id}"
|
||||
|
||||
|
||||
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
class GoogleWorkspaceProvider(OutgoingSyncProvider):
|
||||
"""Sync users from authentik into Google Workspace."""
|
||||
|
||||
delegated_subject = models.EmailField()
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""MicrosoftEntraProviderGroup API Views"""
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import PartialGroupSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
|
||||
|
||||
|
||||
class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
|
||||
@@ -16,7 +12,6 @@ class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
|
||||
group_obj = PartialGroupSerializer(source="group", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
model = MicrosoftEntraProviderGroup
|
||||
fields = [
|
||||
"id",
|
||||
@@ -29,15 +24,7 @@ class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class MicrosoftEntraProviderGroupViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
OutgoingSyncConnectionCreateMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
class MicrosoftEntraProviderGroupViewSet(OutgoingSyncConnectionViewSet):
|
||||
"""MicrosoftEntraProviderGroup Viewset"""
|
||||
|
||||
queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group")
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
"""Microsoft Provider API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import (
|
||||
microsoft_entra_sync,
|
||||
microsoft_entra_sync_objects,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderViewSet
|
||||
|
||||
|
||||
class MicrosoftEntraProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
|
||||
@@ -43,17 +40,10 @@ class MicrosoftEntraProviderSerializer(EnterpriseRequiredMixin, ProviderSerializ
|
||||
extra_kwargs = {}
|
||||
|
||||
|
||||
class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
|
||||
class MicrosoftEntraProviderViewSet(OutgoingSyncProviderViewSet):
|
||||
"""MicrosoftEntraProvider Viewset"""
|
||||
|
||||
queryset = MicrosoftEntraProvider.objects.all()
|
||||
serializer_class = MicrosoftEntraProviderSerializer
|
||||
filterset_fields = [
|
||||
"name",
|
||||
"exclude_users_service_account",
|
||||
"filter_group",
|
||||
]
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_task = microsoft_entra_sync
|
||||
sync_objects_task = microsoft_entra_sync_objects
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""MicrosoftEntraProviderUser API Views"""
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import PartialUserSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
|
||||
|
||||
|
||||
class MicrosoftEntraProviderUserSerializer(ModelSerializer):
|
||||
@@ -16,7 +12,6 @@ class MicrosoftEntraProviderUserSerializer(ModelSerializer):
|
||||
user_obj = PartialUserSerializer(source="user", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
model = MicrosoftEntraProviderUser
|
||||
fields = [
|
||||
"id",
|
||||
@@ -29,15 +24,7 @@ class MicrosoftEntraProviderUserSerializer(ModelSerializer):
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class MicrosoftEntraProviderUserViewSet(
|
||||
OutgoingSyncConnectionCreateMixin,
|
||||
mixins.CreateModelMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
class MicrosoftEntraProviderUserViewSet(OutgoingSyncConnectionViewSet):
|
||||
"""MicrosoftEntraProviderUser Viewset"""
|
||||
|
||||
queryset = MicrosoftEntraProviderUser.objects.all().select_related("user")
|
||||
|
||||
@@ -12,7 +12,6 @@ from dramatiq.actor import Actor
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import (
|
||||
BackchannelProvider,
|
||||
Group,
|
||||
PropertyMapping,
|
||||
User,
|
||||
@@ -75,7 +74,7 @@ class MicrosoftEntraProviderGroup(SerializerModel):
|
||||
return f"Microsoft Entra Provider Group {self.group_id} to {self.provider_id}"
|
||||
|
||||
|
||||
class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
class MicrosoftEntraProvider(OutgoingSyncProvider):
|
||||
"""Sync users from authentik into Microsoft Entra."""
|
||||
|
||||
client_id = models.TextField()
|
||||
|
||||
@@ -237,7 +237,7 @@ class Event(SerializerModel, ExpiringModel):
|
||||
self.save()
|
||||
return self
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
if self._state.adding:
|
||||
LOGGER.info(
|
||||
"Created Event",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from base64 import b64encode
|
||||
from functools import cache as funccache
|
||||
from hashlib import md5, sha256
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.core.cache import cache
|
||||
@@ -27,7 +27,7 @@ CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available"
|
||||
GRAVATAR_STATUS_TTL_SECONDS = 60 * 60 * 8 # 8 Hours
|
||||
|
||||
SVG_XML_NS = "http://www.w3.org/2000/svg"
|
||||
SVG_NS_MAP = {None: SVG_XML_NS}
|
||||
SVG_NS_MAP: dict[str, str] = cast(dict[str, str], {None: SVG_XML_NS})
|
||||
# Match fonts used in web UI
|
||||
SVG_FONTS = [
|
||||
"'RedHatText'",
|
||||
@@ -39,7 +39,7 @@ SVG_FONTS = [
|
||||
]
|
||||
|
||||
|
||||
def avatar_mode_none(user: "User", mode: str) -> str | None:
|
||||
def avatar_mode_none(user: "User", mode: str) -> str:
|
||||
"""No avatar"""
|
||||
return DEFAULT_AVATAR
|
||||
|
||||
@@ -62,7 +62,7 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
|
||||
full_key = CACHE_KEY_GRAVATAR + mail_hash
|
||||
if cache.has_key(full_key):
|
||||
cache.touch(full_key)
|
||||
return cache.get(full_key)
|
||||
return cast(str | None, cache.get(full_key))
|
||||
|
||||
try:
|
||||
# Since we specify a default of 404, do a HEAD request
|
||||
@@ -129,16 +129,16 @@ def generate_avatar_from_name(
|
||||
bg_hex, text_hex = generate_colors(name)
|
||||
|
||||
half_size = size // 2
|
||||
shape = "circle" if rounded else "rect"
|
||||
shape_type = "circle" if rounded else "rect"
|
||||
font_weight = "600" if bold else "400"
|
||||
|
||||
root_element: Element = Element(f"{{{SVG_XML_NS}}}svg", nsmap=SVG_NS_MAP)
|
||||
root_element = Element(f"{{{SVG_XML_NS}}}svg", nsmap=SVG_NS_MAP)
|
||||
root_element.attrib["width"] = f"{size}px"
|
||||
root_element.attrib["height"] = f"{size}px"
|
||||
root_element.attrib["viewBox"] = f"0 0 {size} {size}"
|
||||
root_element.attrib["version"] = "1.1"
|
||||
|
||||
shape = SubElement(root_element, f"{{{SVG_XML_NS}}}{shape}", nsmap=SVG_NS_MAP)
|
||||
shape = SubElement(root_element, f"{{{SVG_XML_NS}}}{shape_type}", nsmap=SVG_NS_MAP)
|
||||
shape.attrib["fill"] = f"#{bg_hex}"
|
||||
shape.attrib["cx"] = f"{half_size}"
|
||||
shape.attrib["cy"] = f"{half_size}"
|
||||
@@ -150,7 +150,7 @@ def generate_avatar_from_name(
|
||||
text.attrib["x"] = "50%"
|
||||
text.attrib["y"] = "50%"
|
||||
text.attrib["style"] = (
|
||||
f"color: #{text_hex}; " "line-height: 1; " f"font-family: {','.join(SVG_FONTS)}; "
|
||||
f"color: #{text_hex}; line-height: 1; font-family: {','.join(SVG_FONTS)}; "
|
||||
)
|
||||
text.attrib["fill"] = f"#{text_hex}"
|
||||
text.attrib["alignment-baseline"] = "middle"
|
||||
@@ -197,7 +197,7 @@ def get_avatar(user: "User", request: HttpRequest | None = None) -> str:
|
||||
}
|
||||
tenant = None
|
||||
if request:
|
||||
tenant = request.tenant
|
||||
tenant = request.tenant # type: ignore[attr-defined]
|
||||
else:
|
||||
tenant = get_current_tenant()
|
||||
modes: str = tenant.avatars
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
@@ -5,11 +7,11 @@ from authentik.lib.config import CONFIG
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def start_debug_server(**kwargs) -> bool:
|
||||
def start_debug_server(**kwargs: Any) -> bool:
|
||||
"""Attempt to start a debugpy server in the current process.
|
||||
Returns true if the server was started successfully, otherwise false"""
|
||||
if not CONFIG.get_bool("debug") and not CONFIG.get_bool("debugger"):
|
||||
return
|
||||
return False
|
||||
try:
|
||||
import debugpy
|
||||
except ImportError:
|
||||
|
||||
@@ -13,10 +13,9 @@ from django.core.exceptions import FieldError
|
||||
from django.http import HttpRequest
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from guardian.utils import get_anonymous_user
|
||||
from rest_framework.serializers import ValidationError
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import User
|
||||
@@ -55,7 +54,7 @@ class BaseEvaluator:
|
||||
# Filename used for exec
|
||||
_filename: str
|
||||
|
||||
def __init__(self, filename: str | None = None):
|
||||
def __init__(self, filename: str | None = None) -> None:
|
||||
self._filename = filename if filename else "BaseEvaluator"
|
||||
# update website/docs/expressions/_objects.md
|
||||
# update website/docs/expressions/_functions.md
|
||||
@@ -133,12 +132,12 @@ class BaseEvaluator:
|
||||
return re.sub(regex, repl, value)
|
||||
|
||||
@staticmethod
|
||||
def expr_is_group_member(user: User, **group_filters) -> bool:
|
||||
def expr_is_group_member(user: User, **group_filters: Any) -> bool:
|
||||
"""Check if `user` is member of group with name `group_name`"""
|
||||
return user.all_groups().filter(**group_filters).exists()
|
||||
|
||||
@staticmethod
|
||||
def expr_user_by(**filters) -> User | None:
|
||||
def expr_user_by(**filters: Any) -> User | None:
|
||||
"""Get user by filters"""
|
||||
try:
|
||||
users = User.objects.filter(**filters)
|
||||
@@ -160,7 +159,7 @@ class BaseEvaluator:
|
||||
return False
|
||||
return len(list(user_devices)) > 0
|
||||
|
||||
def expr_event_create(self, action: str, **kwargs):
|
||||
def expr_event_create(self, action: str, **kwargs: Any) -> None:
|
||||
"""Create event with supplied data and try to extract as much relevant data
|
||||
from the context"""
|
||||
context = self._context.copy()
|
||||
@@ -181,7 +180,7 @@ class BaseEvaluator:
|
||||
return
|
||||
event.save()
|
||||
|
||||
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
|
||||
def expr_func_call_policy(self, name: str, **kwargs: Any) -> PolicyResult:
|
||||
"""Call policy by name, with current request"""
|
||||
policy = Policy.objects.filter(name=name).select_subclasses().first()
|
||||
if not policy:
|
||||
@@ -214,10 +213,10 @@ class BaseEvaluator:
|
||||
provider=provider,
|
||||
user=user,
|
||||
expires=now() + timedelta_from_string(validity),
|
||||
scope=scopes,
|
||||
auth_time=now(),
|
||||
session=session,
|
||||
)
|
||||
access_token.scope = scopes
|
||||
access_token.id_token = IDToken.new(provider, access_token, request)
|
||||
access_token.save()
|
||||
return access_token.token
|
||||
@@ -229,7 +228,7 @@ class BaseEvaluator:
|
||||
body: str | None = None,
|
||||
stage: "EmailStage | None" = None,
|
||||
template: str | None = None,
|
||||
context: dict | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Send an email using authentik's email system
|
||||
|
||||
@@ -316,7 +315,6 @@ class BaseEvaluator:
|
||||
If any exception is raised during execution, it is raised.
|
||||
The result is returned without any type-checking."""
|
||||
with start_span(op="authentik.lib.evaluator.evaluate") as span:
|
||||
span: Span
|
||||
span.description = self._filename
|
||||
span.set_data("expression", expression_source)
|
||||
try:
|
||||
@@ -343,7 +341,7 @@ class BaseEvaluator:
|
||||
raise exc
|
||||
return result
|
||||
|
||||
def handle_error(self, exc: Exception, expression_source: str): # pragma: no cover
|
||||
def handle_error(self, exc: Exception, expression_source: str) -> None: # pragma: no cover
|
||||
"""Exception Handler"""
|
||||
LOGGER.warning("Expression error", exc=exc)
|
||||
|
||||
|
||||
@@ -4,20 +4,20 @@ import string
|
||||
from random import SystemRandom
|
||||
|
||||
|
||||
def generate_code_fixed_length(length=9) -> str:
|
||||
def generate_code_fixed_length(length: int = 9) -> str:
|
||||
"""Generate a numeric code"""
|
||||
rand = SystemRandom()
|
||||
num = rand.randrange(1, 10**length)
|
||||
return str(num).zfill(length)
|
||||
|
||||
|
||||
def generate_id(length=40) -> str:
|
||||
def generate_id(length: int = 40) -> str:
|
||||
"""Generate a random client ID"""
|
||||
rand = SystemRandom()
|
||||
return "".join(rand.choice(string.ascii_letters + string.digits) for x in range(length))
|
||||
|
||||
|
||||
def generate_key(length=128) -> str:
|
||||
def generate_key(length: int = 128) -> str:
|
||||
"""Generate a suitable client secret"""
|
||||
rand = SystemRandom()
|
||||
return "".join(
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
import logging
|
||||
from logging import Logger
|
||||
from os import getpid
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from django.db import connection
|
||||
from structlog.typing import EventDict
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
@@ -19,9 +21,9 @@ LOG_PRE_CHAIN = [
|
||||
]
|
||||
|
||||
|
||||
def get_log_level():
|
||||
def get_log_level() -> str:
|
||||
"""Get log level, clamp trace to debug"""
|
||||
level = CONFIG.get("log_level").upper()
|
||||
level: str = CONFIG.get("log_level").upper()
|
||||
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
|
||||
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
|
||||
# Additionally, the entire code uses debug as highest level
|
||||
@@ -31,7 +33,7 @@ def get_log_level():
|
||||
return level
|
||||
|
||||
|
||||
def structlog_configure():
|
||||
def structlog_configure() -> None:
|
||||
"""Configure structlog itself"""
|
||||
structlog.configure_once(
|
||||
processors=[
|
||||
@@ -56,11 +58,11 @@ def structlog_configure():
|
||||
)
|
||||
|
||||
|
||||
def get_logger_config():
|
||||
def get_logger_config() -> dict[str, Any]:
|
||||
"""Configure python stdlib's logging"""
|
||||
debug = CONFIG.get_bool("debug")
|
||||
global_level = get_log_level()
|
||||
base_config = {
|
||||
base_config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
@@ -121,13 +123,13 @@ def get_logger_config():
|
||||
return base_config
|
||||
|
||||
|
||||
def add_process_id(logger: Logger, method_name: str, event_dict):
|
||||
def add_process_id(logger: Logger, method_name: str, event_dict: EventDict) -> EventDict:
|
||||
"""Add the current process ID"""
|
||||
event_dict["pid"] = getpid()
|
||||
return event_dict
|
||||
|
||||
|
||||
def add_tenant_information(logger: Logger, method_name: str, event_dict):
|
||||
def add_tenant_information(logger: Logger, method_name: str, event_dict: EventDict) -> EventDict:
|
||||
"""Add the current tenant"""
|
||||
tenant = getattr(connection, "tenant", None)
|
||||
schema_name = getattr(connection, "schema_name", None)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""merge utils"""
|
||||
|
||||
from deepmerge import Merger
|
||||
from deepmerge import Merger # type: ignore[attr-defined]
|
||||
|
||||
MERGE_LIST_UNIQUE = Merger(
|
||||
[(list, ["append_unique"]), (dict, ["merge"]), (set, ["union"])], ["override"], ["override"]
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""Migration helpers"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Collection, Generator
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
def fallback_names(app: str, model: str, field: str):
|
||||
def fallback_names(
|
||||
app: str, model: str, field: str
|
||||
) -> Callable[[Apps, BaseDatabaseSchemaEditor], None]:
|
||||
"""Factory function that checks all instances of `app`.`model` instance's `field`
|
||||
to prevent any duplicates"""
|
||||
|
||||
def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor) -> None:
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
klass = apps.get_model(app, model)
|
||||
@@ -35,7 +37,7 @@ def fallback_names(app: str, model: str, field: str):
|
||||
return migrator
|
||||
|
||||
|
||||
def progress_bar(iterable: Iterable):
|
||||
def progress_bar[R](iterable: Collection[R]) -> Generator[R]:
|
||||
"""Call in a loop to create terminal progress bar
|
||||
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console"""
|
||||
|
||||
@@ -50,7 +52,7 @@ def progress_bar(iterable: Iterable):
|
||||
if total < 1:
|
||||
return
|
||||
|
||||
def print_progress_bar(iteration):
|
||||
def print_progress_bar(iteration: int) -> None:
|
||||
"""Progress Bar Printing Function"""
|
||||
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
|
||||
filled_length = int(length * iteration // total)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""authentik sentry integration"""
|
||||
|
||||
from asyncio.exceptions import CancelledError
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
|
||||
@@ -34,6 +34,9 @@ from authentik.lib.utils.reflection import get_env
|
||||
LOGGER = get_logger()
|
||||
_root_path = CONFIG.get("web.path", "/")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentry_sdk._types import Event
|
||||
|
||||
|
||||
class SentryIgnoredException(Exception):
|
||||
"""Base Class for all errors that are suppressed, and not sent to sentry."""
|
||||
@@ -79,10 +82,11 @@ class SentryTransport(HttpTransport):
|
||||
|
||||
def __init__(self, options: dict[str, Any]) -> None:
|
||||
super().__init__(options)
|
||||
assert self.parsed_dsn is not None # nosec
|
||||
self._auth = self.parsed_dsn.to_auth(authentik_user_agent())
|
||||
|
||||
|
||||
def sentry_init(**sentry_init_kwargs):
|
||||
def sentry_init(**sentry_init_kwargs: Any) -> None:
|
||||
"""Configure sentry SDK"""
|
||||
sentry_env = CONFIG.get("error_reporting.environment", "customer")
|
||||
kwargs = {
|
||||
@@ -116,7 +120,7 @@ def sentry_init(**sentry_init_kwargs):
|
||||
set_tag("authentik.component", "backend")
|
||||
|
||||
|
||||
def traces_sampler(sampling_context: dict) -> float:
|
||||
def traces_sampler(sampling_context: dict[str, Any]) -> float:
|
||||
"""Custom sampler to ignore certain routes"""
|
||||
path = sampling_context.get("asgi_scope", {}).get("path", "")
|
||||
_type = sampling_context.get("asgi_scope", {}).get("type", "")
|
||||
@@ -135,7 +139,7 @@ def should_ignore_exception(exc: Exception) -> bool:
|
||||
return isinstance(exc, ignored_classes)
|
||||
|
||||
|
||||
def before_send(event: dict, hint: dict) -> dict | None:
|
||||
def before_send(event: "Event", hint: dict[str, Any]) -> "Event | None":
|
||||
"""Check if error is database error, and ignore if so"""
|
||||
exc_value = None
|
||||
if "exc_info" in hint:
|
||||
@@ -157,7 +161,7 @@ def before_send(event: dict, hint: dict) -> dict | None:
|
||||
return event
|
||||
|
||||
|
||||
def get_http_meta():
|
||||
def get_http_meta() -> dict[str, Any]:
|
||||
"""Get sentry-related meta key-values"""
|
||||
scope = get_current_scope()
|
||||
meta = {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.http import HttpRequest
|
||||
@@ -20,7 +21,7 @@ class PropertyMappingManager:
|
||||
|
||||
_evaluators: list[PropertyMappingEvaluator]
|
||||
|
||||
globals: dict
|
||||
globals: dict[str, Any]
|
||||
|
||||
__has_compiled: bool
|
||||
|
||||
@@ -40,7 +41,7 @@ class PropertyMappingManager:
|
||||
self.globals = {}
|
||||
self.__has_compiled = False
|
||||
|
||||
def compile(self):
|
||||
def compile(self) -> None:
|
||||
self._evaluators = []
|
||||
for mapping in self.query_set:
|
||||
if not isinstance(mapping, self.mapping_subclass):
|
||||
@@ -58,8 +59,8 @@ class PropertyMappingManager:
|
||||
user: User | None,
|
||||
request: HttpRequest | None,
|
||||
return_mapping: bool = False,
|
||||
**kwargs,
|
||||
) -> Generator[tuple[dict, PropertyMapping]]:
|
||||
**kwargs: Any,
|
||||
) -> Generator[tuple[Any, PropertyMapping]]:
|
||||
"""Iterate over all mappings that were pre-compiled and
|
||||
execute all of them with the given context"""
|
||||
if not self.__has_compiled:
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Model
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.results.errors import ResultFailure
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField
|
||||
from rest_framework.mixins import (
|
||||
CreateModelMixin,
|
||||
DestroyModelMixin,
|
||||
ListModelMixin,
|
||||
RetrieveModelMixin,
|
||||
)
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.logs import LogEventSerializer
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.sync.api import SyncStatusSerializer
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||
@@ -36,11 +47,19 @@ class SyncObjectResultSerializer(PassiveSerializer):
|
||||
messages = LogEventSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class OutgoingSyncProviderStatusMixin:
|
||||
class OutgoingSyncProviderViewSet(UsedByMixin, ModelViewSet[OutgoingSyncProvider]):
|
||||
"""Common API Endpoints for Outgoing sync providers"""
|
||||
|
||||
sync_task: Actor
|
||||
sync_objects_task: Actor
|
||||
sync_task: Actor[[int, Actor[[str, int, int, bool], None]], None]
|
||||
sync_objects_task: Actor[[str, int, int, bool, dict[str, Any | None]], None]
|
||||
|
||||
filterset_fields = [
|
||||
"name",
|
||||
"exclude_users_service_account",
|
||||
"filter_group",
|
||||
]
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
|
||||
@extend_schema(responses={200: SyncStatusSerializer()})
|
||||
@action(
|
||||
@@ -68,20 +87,20 @@ class OutgoingSyncProviderStatusMixin:
|
||||
if not sync_schedule:
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
||||
last_task: Task = (
|
||||
last_task = (
|
||||
sync_schedule.tasks.filter(state__in=(TaskStatus.DONE, TaskStatus.REJECTED))
|
||||
.order_by("-mtime")
|
||||
.first()
|
||||
)
|
||||
last_successful_task: Task = (
|
||||
last_successful_task = (
|
||||
sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO))
|
||||
.order_by("-mtime")
|
||||
.first()
|
||||
)
|
||||
|
||||
if last_task:
|
||||
if last_task is not None:
|
||||
status["last_sync_status"] = last_task.aggregated_status
|
||||
if last_successful_task:
|
||||
if last_successful_task is not None:
|
||||
status["last_successful_sync"] = last_successful_task.mtime
|
||||
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
@@ -111,7 +130,7 @@ class OutgoingSyncProviderStatusMixin:
|
||||
"page": 1,
|
||||
"provider_pk": provider.pk,
|
||||
"override_dry_run": params.validated_data["override_dry_run"],
|
||||
"pk": pk,
|
||||
"filter": {"pk": pk},
|
||||
},
|
||||
retries=0,
|
||||
rel_obj=provider,
|
||||
@@ -126,13 +145,20 @@ class OutgoingSyncProviderStatusMixin:
|
||||
return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data)
|
||||
|
||||
|
||||
class OutgoingSyncConnectionCreateMixin:
|
||||
"""Mixin for connection objects that fetches remote data upon creation"""
|
||||
|
||||
def perform_create(self, serializer: ModelSerializer):
|
||||
class OutgoingSyncConnectionViewSet(
|
||||
CreateModelMixin,
|
||||
RetrieveModelMixin,
|
||||
DestroyModelMixin,
|
||||
ListModelMixin,
|
||||
UsedByMixin,
|
||||
GenericViewSet[SerializerModel],
|
||||
):
|
||||
def perform_create(self, serializer: ModelSerializer) -> None: # type: ignore[override]
|
||||
super().perform_create(serializer)
|
||||
try:
|
||||
instance = serializer.instance
|
||||
if instance is None:
|
||||
return
|
||||
client = instance.provider.client_for_model(instance.__class__)
|
||||
client.update_single_attribute(instance)
|
||||
instance.save()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Basic outgoing sync Client"""
|
||||
|
||||
from collections.abc import MutableMapping
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from deepmerge import always_merger
|
||||
from django.db import DatabaseError
|
||||
@@ -18,11 +19,11 @@ from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSy
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models import Model
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
|
||||
|
||||
class Direction(StrEnum):
|
||||
|
||||
add = "add"
|
||||
remove = "remove"
|
||||
|
||||
@@ -36,7 +37,10 @@ SAFE_METHODS = [
|
||||
|
||||
|
||||
class BaseOutgoingSyncClient[
|
||||
TModel: "Model", TConnection: "Model", TSchema: dict, TProvider: "OutgoingSyncProvider"
|
||||
TModel: "User | Group",
|
||||
TConnection: "Model",
|
||||
TSchema: MutableMapping[Any, Any],
|
||||
TProvider: "OutgoingSyncProvider",
|
||||
]:
|
||||
"""Basic Outgoing sync client Client"""
|
||||
|
||||
@@ -55,14 +59,17 @@ class BaseOutgoingSyncClient[
|
||||
"""Create object in remote destination"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, obj: TModel, connection: TConnection):
|
||||
def update(self, obj: TModel, connection: TConnection) -> None:
|
||||
"""Update object in remote destination"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def write(self, obj: TModel) -> tuple[TConnection, bool]:
|
||||
def update_group(self, group: "Group", action: Direction, users_set: list[Any]) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def write(self, obj: TModel) -> tuple[TConnection | None, bool]:
|
||||
"""Write object to destination. Uses self.create and self.update, but
|
||||
can be overwritten for further logic"""
|
||||
connection = self.connection_type.objects.filter(
|
||||
connection = self.connection_type.objects.filter( # type: ignore[attr-defined]
|
||||
provider=self.provider, **{self.connection_type_query: obj}
|
||||
).first()
|
||||
try:
|
||||
@@ -82,13 +89,13 @@ class BaseOutgoingSyncClient[
|
||||
connection.delete()
|
||||
return None, False
|
||||
|
||||
def delete(self, obj: TModel):
|
||||
def delete(self, obj: TModel) -> None:
|
||||
"""Delete object from destination"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults) -> TSchema:
|
||||
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults: Any) -> TSchema:
|
||||
"""Convert object to destination schema"""
|
||||
raw_final_object = {}
|
||||
raw_final_object: dict[Any, Any] = {}
|
||||
try:
|
||||
eval_kwargs = {
|
||||
"request": None,
|
||||
@@ -97,7 +104,7 @@ class BaseOutgoingSyncClient[
|
||||
obj._meta.model_name: obj,
|
||||
}
|
||||
eval_kwargs.setdefault("user", None)
|
||||
for value in self.mapper.iter_eval(**eval_kwargs):
|
||||
for value in self.mapper.iter_eval(**eval_kwargs): # type: ignore[arg-type, misc]
|
||||
always_merger.merge(raw_final_object, value)
|
||||
except ControlFlowException as exc:
|
||||
raise exc from exc
|
||||
@@ -113,16 +120,16 @@ class BaseOutgoingSyncClient[
|
||||
raise StopSync(ValueError("No mappings configured"), obj)
|
||||
for key, value in defaults.items():
|
||||
raw_final_object.setdefault(key, value)
|
||||
return raw_final_object
|
||||
return cast(TSchema, raw_final_object)
|
||||
|
||||
def discover(self):
|
||||
def discover(self) -> None:
|
||||
"""Optional method. Can be used to implement a "discovery" where
|
||||
upon creation of this provider, this function will be called and can
|
||||
pre-link any users/groups in the remote system with the respective
|
||||
object in authentik based on a common identifier"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_attribute(self, connection: TConnection):
|
||||
def update_single_attribute(self, connection: TConnection) -> None:
|
||||
"""Update connection attributes on a connection object, when the connection
|
||||
is manually created"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
|
||||
|
||||
@@ -24,16 +26,16 @@ class BadRequestSyncException(BaseSyncException):
|
||||
class DryRunRejected(BaseSyncException):
|
||||
"""When dry_run is enabled and a provider dropped a mutating request"""
|
||||
|
||||
def __init__(self, url: str, method: str, body: dict):
|
||||
def __init__(self, url: str, method: str, body: dict[Any, Any]) -> None:
|
||||
super().__init__()
|
||||
self.url = url
|
||||
self.method = method
|
||||
self.body = body
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"Dry-run rejected request: {self.method} {self.url}"
|
||||
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import pglock
|
||||
from django.core.paginator import Paginator
|
||||
from django.core.validators import MinValueValidator
|
||||
from django.db import connection, models
|
||||
from django.db.models import Model, QuerySet, TextChoices
|
||||
from django.db.models import QuerySet, TextChoices
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import Actor
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.models import BackchannelProvider, Group, User
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.utils.time import fqdn_rand, timedelta_from_string, timedelta_string_validator
|
||||
from authentik.tasks.schedules.common import ScheduleSpec
|
||||
@@ -24,7 +24,7 @@ class OutgoingSyncDeleteAction(TextChoices):
|
||||
SUSPEND = "suspend"
|
||||
|
||||
|
||||
class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
class OutgoingSyncProvider(ScheduledModel, BackchannelProvider):
|
||||
"""Base abstract models for providers implementing outgoing sync"""
|
||||
|
||||
sync_page_size = models.PositiveIntegerField(
|
||||
@@ -56,7 +56,7 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
|
||||
def get_paginator[T: User | Group](self, type: type[T]) -> "Paginator[T]":
|
||||
return Paginator(self.get_object_qs(type), self.sync_page_size)
|
||||
|
||||
def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
|
||||
@@ -74,13 +74,15 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
def sync_lock(self) -> pglock.advisory:
|
||||
"""Postgres lock for syncing to prevent multiple parallel syncs happening"""
|
||||
return pglock.advisory(
|
||||
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}",
|
||||
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}", # type: ignore[attr-defined]
|
||||
timeout=0,
|
||||
side_effect=pglock.Return,
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_actor(self) -> Actor:
|
||||
def sync_actor(
|
||||
self,
|
||||
) -> Actor[[int, Actor[[str, int, int, bool, dict[str, Any] | None], None]], None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -94,6 +96,6 @@ class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
"time_limit": self.get_sync_time_limit_ms(),
|
||||
},
|
||||
send_on_save=True,
|
||||
crontab=f"{fqdn_rand(self.pk)} */4 * * *",
|
||||
crontab=f"{fqdn_rand(str(self.pk))} */4 * * *",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from django.db.models import Model
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from django.db.models.signals import m2m_changed, post_save, pre_delete
|
||||
from dramatiq.actor import Actor
|
||||
|
||||
@@ -7,22 +10,24 @@ from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=User | Group)
|
||||
|
||||
|
||||
def register_signals(
|
||||
provider_type: type[OutgoingSyncProvider],
|
||||
task_sync_direct_dispatch: Actor[[str, str | int, str], None],
|
||||
task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
|
||||
):
|
||||
task_sync_direct_dispatch: Actor[[str, Any, str], None],
|
||||
task_sync_m2m_dispatch: Actor[[Any, str, list[Any], bool], None],
|
||||
) -> None:
|
||||
"""Register sync signals"""
|
||||
uid = class_to_path(provider_type)
|
||||
|
||||
def model_post_save(
|
||||
sender: type[Model],
|
||||
instance: User | Group,
|
||||
sender: type[ModelT],
|
||||
instance: ModelT,
|
||||
created: bool,
|
||||
update_fields: list[str] | None = None,
|
||||
**_,
|
||||
):
|
||||
update_fields: Iterable[str] | None = None,
|
||||
**_: Any,
|
||||
) -> None:
|
||||
"""Post save handler"""
|
||||
# Special case for user object; don't start sync task when we've only updated `last_login`
|
||||
# This primarily happens during user login
|
||||
@@ -37,7 +42,7 @@ def register_signals(
|
||||
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
|
||||
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
|
||||
|
||||
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
|
||||
def model_pre_delete(sender: type[ModelT], instance: ModelT, **_: Any) -> None:
|
||||
"""Pre-delete handler"""
|
||||
task_sync_direct_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
@@ -49,8 +54,13 @@ def register_signals(
|
||||
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
|
||||
|
||||
def model_m2m_changed(
|
||||
sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
|
||||
):
|
||||
sender: type[ModelT],
|
||||
instance: ModelT,
|
||||
action: str,
|
||||
pk_set: set[int | UUID],
|
||||
reverse: bool,
|
||||
**_: Any,
|
||||
) -> None:
|
||||
"""Sync group membership"""
|
||||
if action not in ["post_add", "post_remove"]:
|
||||
return
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from django.core.paginator import Paginator
|
||||
from django.db.models import Model, QuerySet
|
||||
from django.db.models.query import Q
|
||||
from django.db.models import Model, Q
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.composition import group
|
||||
from dramatiq.errors import Retry
|
||||
from dramatiq.message import Message
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.core.expression.exceptions import SkipObjectException
|
||||
@@ -38,11 +40,11 @@ class SyncTasks:
|
||||
self,
|
||||
current_task: Task,
|
||||
provider: OutgoingSyncProvider,
|
||||
sync_objects: Actor[[str, int, int, bool], None],
|
||||
paginator: Paginator,
|
||||
sync_objects: Actor[[str, int, int, bool, dict[str, Any] | None], None],
|
||||
paginator: "Paginator[User | Group]",
|
||||
object_type: type[User | Group],
|
||||
**options,
|
||||
):
|
||||
**options: Any,
|
||||
) -> list[Message[None]]:
|
||||
tasks = []
|
||||
time_limit = timedelta_from_string(provider.sync_page_timeout).total_seconds() * 1000
|
||||
for page in paginator.page_range:
|
||||
@@ -60,14 +62,14 @@ class SyncTasks:
|
||||
def sync(
|
||||
self,
|
||||
provider_pk: int,
|
||||
sync_objects: Actor[[str, int, int, bool], None],
|
||||
):
|
||||
sync_objects: Actor[[str, int, int, bool, dict[str, Any] | None], None],
|
||||
) -> None:
|
||||
task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
provider_pk=provider_pk,
|
||||
)
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
provider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
@@ -82,7 +84,7 @@ class SyncTasks:
|
||||
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
|
||||
return
|
||||
try:
|
||||
users_tasks = group(
|
||||
users_tasks = group( # type: ignore[no-untyped-call]
|
||||
self.sync_paginator(
|
||||
current_task=task,
|
||||
provider=provider,
|
||||
@@ -91,7 +93,7 @@ class SyncTasks:
|
||||
object_type=User,
|
||||
)
|
||||
)
|
||||
group_tasks = group(
|
||||
group_tasks = group( # type: ignore[no-untyped-call]
|
||||
self.sync_paginator(
|
||||
current_task=task,
|
||||
provider=provider,
|
||||
@@ -100,12 +102,12 @@ class SyncTasks:
|
||||
object_type=Group,
|
||||
)
|
||||
)
|
||||
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
|
||||
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
|
||||
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) # type: ignore[no-untyped-call]
|
||||
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) # type: ignore[no-untyped-call]
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient sync exception", exc=exc)
|
||||
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
|
||||
raise Retry() from exc
|
||||
raise Retry() from exc # type: ignore[no-untyped-call]
|
||||
except StopSync as exc:
|
||||
task.error(exc)
|
||||
return
|
||||
@@ -115,11 +117,11 @@ class SyncTasks:
|
||||
object_type: str,
|
||||
page: int,
|
||||
provider_pk: int,
|
||||
override_dry_run=False,
|
||||
**filter,
|
||||
):
|
||||
override_dry_run: bool = False,
|
||||
filter: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
task = CurrentTask.get_task()
|
||||
_object_type: type[Model] = path_to_class(object_type)
|
||||
_object_type: type[User | Group] = path_to_class(object_type)
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
provider_pk=provider_pk,
|
||||
@@ -140,6 +142,8 @@ class SyncTasks:
|
||||
client = provider.client_for_model(_object_type)
|
||||
except TransientSyncException:
|
||||
return
|
||||
if filter is None:
|
||||
filter = {}
|
||||
paginator = Paginator(
|
||||
provider.get_object_qs(_object_type).filter(**filter),
|
||||
provider.sync_page_size,
|
||||
@@ -150,7 +154,6 @@ class SyncTasks:
|
||||
self.logger.debug("starting sync for page", page=page)
|
||||
task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
|
||||
for obj in paginator.page(page).object_list:
|
||||
obj: Model
|
||||
try:
|
||||
client.write(obj)
|
||||
except SkipObjectException:
|
||||
@@ -189,11 +192,11 @@ class SyncTasks:
|
||||
|
||||
def sync_signal_direct_dispatch(
|
||||
self,
|
||||
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
|
||||
task_sync_signal_direct: Actor[[str, Any, int, str], None],
|
||||
model: str,
|
||||
pk: str | int,
|
||||
pk: Any,
|
||||
raw_op: str,
|
||||
):
|
||||
) -> None:
|
||||
model_class: type[Model] = path_to_class(model)
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
@@ -207,19 +210,19 @@ class SyncTasks:
|
||||
def sync_signal_direct(
|
||||
self,
|
||||
model: str,
|
||||
pk: str | int,
|
||||
pk: Any,
|
||||
provider_pk: int,
|
||||
raw_op: str,
|
||||
):
|
||||
) -> None:
|
||||
task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
)
|
||||
model_class: type[Model] = path_to_class(model)
|
||||
model_class: type[User | Group] = path_to_class(model)
|
||||
instance = model_class.objects.filter(pk=pk).first()
|
||||
if not instance:
|
||||
return
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
provider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
@@ -244,7 +247,7 @@ class SyncTasks:
|
||||
if operation == Direction.remove:
|
||||
client.delete(instance)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
raise Retry() from exc # type: ignore[no-untyped-call]
|
||||
except SkipObjectException:
|
||||
return
|
||||
except DryRunRejected as exc:
|
||||
@@ -254,12 +257,12 @@ class SyncTasks:
|
||||
|
||||
def sync_signal_m2m_dispatch(
|
||||
self,
|
||||
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
|
||||
instance_pk: str,
|
||||
task_sync_signal_m2m: Actor[[Any, int, str, list[Any]], None],
|
||||
instance_pk: Any,
|
||||
action: str,
|
||||
pk_set: list[int],
|
||||
pk_set: list[Any],
|
||||
reverse: bool,
|
||||
):
|
||||
) -> None:
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
@@ -281,11 +284,11 @@ class SyncTasks:
|
||||
|
||||
def sync_signal_m2m(
|
||||
self,
|
||||
group_pk: str,
|
||||
group_pk: Any,
|
||||
provider_pk: int,
|
||||
action: str,
|
||||
pk_set: list[int],
|
||||
):
|
||||
pk_set: list[Any],
|
||||
) -> None:
|
||||
task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
@@ -293,7 +296,7 @@ class SyncTasks:
|
||||
group = Group.objects.filter(pk=group_pk).first()
|
||||
if not group:
|
||||
return
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
provider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
@@ -302,7 +305,7 @@ class SyncTasks:
|
||||
return
|
||||
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet = provider.get_object_qs(Group)
|
||||
queryset = provider.get_object_qs(Group)
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=group_pk).exists():
|
||||
@@ -315,9 +318,9 @@ class SyncTasks:
|
||||
operation = Direction.add
|
||||
if action == "post_remove":
|
||||
operation = Direction.remove
|
||||
client.update_group(group, operation, pk_set)
|
||||
client.update_group(group, cast(Direction, operation), pk_set)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
raise Retry() from exc # type: ignore[no-untyped-call]
|
||||
except SkipObjectException:
|
||||
return
|
||||
except DryRunRejected as exc:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test Evaluator base functions"""
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import NonCallableMock, patch
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.urls import reverse
|
||||
@@ -17,27 +17,27 @@ from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
|
||||
class TestEvaluator(TestCase):
|
||||
"""Test Evaluator base functions"""
|
||||
|
||||
def test_expr_regex_match(self):
|
||||
def test_expr_regex_match(self) -> None:
|
||||
"""Test expr_regex_match"""
|
||||
self.assertFalse(BaseEvaluator.expr_regex_match("foo", "bar"))
|
||||
self.assertTrue(BaseEvaluator.expr_regex_match("foo", "foo"))
|
||||
|
||||
def test_expr_regex_replace(self):
|
||||
def test_expr_regex_replace(self) -> None:
|
||||
"""Test expr_regex_replace"""
|
||||
self.assertEqual(BaseEvaluator.expr_regex_replace("foo", "o", "a"), "faa")
|
||||
|
||||
def test_expr_user_by(self):
|
||||
def test_expr_user_by(self) -> None:
|
||||
"""Test expr_user_by"""
|
||||
user = create_test_admin_user()
|
||||
self.assertIsNotNone(BaseEvaluator.expr_user_by(username=user.username))
|
||||
self.assertIsNone(BaseEvaluator.expr_user_by(username="bar"))
|
||||
self.assertIsNone(BaseEvaluator.expr_user_by(foo="bar"))
|
||||
|
||||
def test_expr_is_group_member(self):
|
||||
def test_expr_is_group_member(self) -> None:
|
||||
"""Test expr_is_group_member"""
|
||||
self.assertFalse(BaseEvaluator.expr_is_group_member(create_test_admin_user(), name="test"))
|
||||
|
||||
def test_expr_event_create(self):
|
||||
def test_expr_event_create(self) -> None:
|
||||
"""Test expr_event_create"""
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
evaluator._context = {
|
||||
@@ -46,10 +46,11 @@ class TestEvaluator(TestCase):
|
||||
evaluator.evaluate("ak_create_event('foo', bar='baz')")
|
||||
event = Event.objects.filter(action="custom_foo").first()
|
||||
self.assertIsNotNone(event)
|
||||
assert event is not None # nosec
|
||||
self.assertEqual(event.context, {"bar": "baz", "foo": "bar"})
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_expr_create_jwt(self):
|
||||
def test_expr_create_jwt(self) -> None:
|
||||
"""Test expr_create_jwt"""
|
||||
rf = RequestFactory()
|
||||
user = create_test_user()
|
||||
@@ -81,7 +82,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertEqual(decoded["preferred_username"], user.username)
|
||||
|
||||
@patch("authentik.stages.email.tasks.send_mails")
|
||||
def test_expr_send_email_with_body(self, mock_send_mails):
|
||||
def test_expr_send_email_with_body(self, mock_send_mails: NonCallableMock) -> None:
|
||||
"""Test ak_send_email with body parameter"""
|
||||
user = create_test_user()
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
@@ -108,7 +109,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertEqual(message.body, "Test Body")
|
||||
|
||||
@patch("authentik.stages.email.tasks.send_mails")
|
||||
def test_expr_send_email_with_template(self, mock_send_mails):
|
||||
def test_expr_send_email_with_template(self, mock_send_mails: NonCallableMock) -> None:
|
||||
"""Test ak_send_email with template parameter"""
|
||||
user = create_test_user()
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
@@ -123,7 +124,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertTrue(result)
|
||||
mock_send_mails.assert_called_once()
|
||||
|
||||
def test_expr_send_email_validation_errors(self):
|
||||
def test_expr_send_email_validation_errors(self) -> None:
|
||||
"""Test ak_send_email validation errors"""
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
|
||||
@@ -141,7 +142,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertIn("Either body or template parameter must be provided", str(cm.exception))
|
||||
|
||||
@patch("authentik.stages.email.tasks.send_mails")
|
||||
def test_expr_send_email_with_custom_stage(self, mock_send_mails):
|
||||
def test_expr_send_email_with_custom_stage(self, mock_send_mails: NonCallableMock) -> None:
|
||||
"""Test ak_send_email with custom EmailStage"""
|
||||
from authentik.stages.email.models import EmailStage
|
||||
|
||||
@@ -170,7 +171,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertFalse(stage.use_global_settings)
|
||||
|
||||
@patch("authentik.stages.email.tasks.send_mails")
|
||||
def test_expr_send_email_with_context(self, mock_send_mails):
|
||||
def test_expr_send_email_with_context(self, mock_send_mails: NonCallableMock) -> None:
|
||||
"""Test ak_send_email with custom context parameter"""
|
||||
user = create_test_user()
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
@@ -199,7 +200,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertIn("http://localhost", message.body)
|
||||
|
||||
@patch("authentik.stages.email.tasks.send_mails")
|
||||
def test_expr_send_email_multiple_addresses(self, mock_send_mails):
|
||||
def test_expr_send_email_multiple_addresses(self, mock_send_mails: NonCallableMock) -> None:
|
||||
"""Test ak_send_email with multiple email addresses"""
|
||||
user = create_test_user()
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
@@ -226,7 +227,7 @@ class TestEvaluator(TestCase):
|
||||
self.assertEqual(message.to, ["user1@example.com", "user2@example.com"])
|
||||
self.assertEqual(message.body, "Test Body")
|
||||
|
||||
def test_expr_send_email_multiple_addresses_validation(self):
|
||||
def test_expr_send_email_multiple_addresses_validation(self) -> None:
|
||||
"""Test ak_send_email validation with multiple addresses"""
|
||||
evaluator = BaseEvaluator(generate_id())
|
||||
|
||||
|
||||
@@ -15,27 +15,27 @@ class TestHTTP(TestCase):
|
||||
self.user = create_test_admin_user()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_bad_request_message(self):
|
||||
def test_bad_request_message(self) -> None:
|
||||
"""test bad_request_message"""
|
||||
request = self.factory.get("/")
|
||||
self.assertEqual(bad_request_message(request, "foo").status_code, 400)
|
||||
|
||||
def test_normal(self):
|
||||
def test_normal(self) -> None:
|
||||
"""Test normal request"""
|
||||
request = self.factory.get("/")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||
|
||||
def test_forward_for(self):
|
||||
def test_forward_for(self) -> None:
|
||||
"""Test x-forwarded-for request"""
|
||||
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
|
||||
|
||||
def test_forward_for_invalid(self):
|
||||
def test_forward_for_invalid(self) -> None:
|
||||
"""Test invalid forward for"""
|
||||
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
|
||||
|
||||
def test_fake_outpost(self):
|
||||
def test_fake_outpost(self) -> None:
|
||||
"""Test faked IP which is overridden by an outpost"""
|
||||
token = Token.objects.create(
|
||||
identifier="test", user=self.user, intent=TokenIntents.INTENT_API
|
||||
@@ -43,7 +43,7 @@ class TestHTTP(TestCase):
|
||||
# Invalid, non-existent token
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
**{ # type: ignore[arg-type]
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||
ClientIPMiddleware.outpost_token_header: "abc",
|
||||
},
|
||||
@@ -52,7 +52,7 @@ class TestHTTP(TestCase):
|
||||
# Invalid, user doesn't have permissions
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
**{ # type: ignore[arg-type]
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||
ClientIPMiddleware.outpost_token_header: token.key,
|
||||
},
|
||||
@@ -63,7 +63,7 @@ class TestHTTP(TestCase):
|
||||
self.user.save()
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
**{ # type: ignore[arg-type]
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "foobar",
|
||||
ClientIPMiddleware.outpost_token_header: token.key,
|
||||
},
|
||||
@@ -74,7 +74,7 @@ class TestHTTP(TestCase):
|
||||
self.user.save()
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
**{ # type: ignore[arg-type]
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||
ClientIPMiddleware.outpost_token_header: token.key,
|
||||
},
|
||||
|
||||
@@ -8,10 +8,10 @@ from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
|
||||
class TestSentry(TestCase):
|
||||
"""test sentry integration"""
|
||||
|
||||
def test_error_not_sent(self):
|
||||
def test_error_not_sent(self) -> None:
|
||||
"""Test SentryIgnoredError not sent"""
|
||||
self.assertTrue(should_ignore_exception(SentryIgnoredException()))
|
||||
|
||||
def test_error_sent(self):
|
||||
def test_error_sent(self) -> None:
|
||||
"""Test error sent"""
|
||||
self.assertFalse(should_ignore_exception(ValueError()))
|
||||
|
||||
@@ -5,7 +5,6 @@ from collections.abc import Callable
|
||||
from django.test import TestCase
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
|
||||
@@ -14,10 +13,10 @@ class TestModels(TestCase):
|
||||
"""Generic model properties tests"""
|
||||
|
||||
|
||||
def model_tester_factory(test_model: type[Stage]) -> Callable:
|
||||
def model_tester_factory(test_model: type[SerializerModel]) -> Callable[[TestModels], None]:
|
||||
"""Test a form"""
|
||||
|
||||
def tester(self: TestModels):
|
||||
def tester(self: TestModels) -> None:
|
||||
try:
|
||||
model_class = None
|
||||
if test_model._meta.abstract: # pragma: no cover
|
||||
@@ -31,4 +30,4 @@ def model_tester_factory(test_model: type[Stage]) -> Callable:
|
||||
|
||||
|
||||
for model in all_subclasses(SerializerModel):
|
||||
setattr(TestModels, f"test_model_{model.__name__}", model_tester_factory(model))
|
||||
setattr(TestModels, f"test_model_{model.__name__}", model_tester_factory(model)) # type: ignore[type-abstract]
|
||||
|
||||
@@ -10,6 +10,6 @@ from authentik.lib.utils.reflection import path_to_class
|
||||
class TestReflectionUtils(TestCase):
|
||||
"""Test Reflection-utils"""
|
||||
|
||||
def test_path_to_class(self):
|
||||
def test_path_to_class(self) -> None:
|
||||
"""Test path_to_class"""
|
||||
self.assertEqual(path_to_class("datetime.datetime"), datetime)
|
||||
|
||||
@@ -11,20 +11,20 @@ from authentik.lib.utils.time import timedelta_from_string, timedelta_string_val
|
||||
class TestTimeUtils(TestCase):
|
||||
"""Test time-utils"""
|
||||
|
||||
def test_valid(self):
|
||||
def test_valid(self) -> None:
|
||||
"""Test valid expression"""
|
||||
expr = "hours=3;minutes=1"
|
||||
expected = timedelta(hours=3, minutes=1)
|
||||
self.assertEqual(timedelta_from_string(expr), expected)
|
||||
|
||||
def test_invalid(self):
|
||||
def test_invalid(self) -> None:
|
||||
"""Test invalid expression"""
|
||||
with self.assertRaises(ValueError):
|
||||
timedelta_from_string("foo")
|
||||
with self.assertRaises(ValueError):
|
||||
timedelta_from_string("bar=baz")
|
||||
|
||||
def test_validation(self):
|
||||
def test_validation(self) -> None:
|
||||
"""Test Django model field validator"""
|
||||
with self.assertRaises(ValidationError):
|
||||
timedelta_string_validator("foo")
|
||||
|
||||
@@ -2,23 +2,31 @@
|
||||
|
||||
from inspect import currentframe
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.messages.middleware import MessageMiddleware
|
||||
from django.contrib.sessions.middleware import SessionMiddleware
|
||||
from django.http import HttpRequest
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.test.client import RequestFactory
|
||||
from guardian.utils import get_anonymous_user
|
||||
|
||||
from authentik.core.models import User
|
||||
|
||||
def dummy_get_response(request: HttpRequest): # pragma: no cover
|
||||
|
||||
def dummy_get_response(request: HttpRequest) -> HttpResponse: # pragma: no cover
|
||||
"""Dummy get_response for SessionMiddleware"""
|
||||
return None
|
||||
return HttpResponse()
|
||||
|
||||
|
||||
def load_fixture(path: str, **kwargs) -> str:
|
||||
def load_fixture(path: str, **kwargs: Any) -> str:
|
||||
"""Load fixture, optionally formatting it with kwargs"""
|
||||
current = currentframe()
|
||||
if current is None:
|
||||
return ""
|
||||
parent = current.f_back
|
||||
if parent is None:
|
||||
return ""
|
||||
calling_file_path = parent.f_globals["__file__"]
|
||||
with open(Path(calling_file_path).resolve().parent / Path(path), encoding="utf-8") as _fixture:
|
||||
fixture = _fixture.read()
|
||||
@@ -28,17 +36,17 @@ def load_fixture(path: str, **kwargs) -> str:
|
||||
return fixture
|
||||
|
||||
|
||||
def get_request(*args, user=None, **kwargs):
|
||||
def get_request(*args: Any, user: User | None = None, **kwargs: Any) -> WSGIRequest:
|
||||
"""Get a request with usable session"""
|
||||
request = RequestFactory().get(*args, **kwargs)
|
||||
if user:
|
||||
if user is not None:
|
||||
request.user = user
|
||||
else:
|
||||
request.user = get_anonymous_user()
|
||||
middleware = SessionMiddleware(dummy_get_response)
|
||||
middleware.process_request(request)
|
||||
session_middleware = SessionMiddleware(dummy_get_response)
|
||||
session_middleware.process_request(request)
|
||||
request.session.save()
|
||||
middleware = MessageMiddleware(dummy_get_response)
|
||||
middleware.process_request(request)
|
||||
message_middleware = MessageMiddleware(dummy_get_response)
|
||||
message_middleware.process_request(request)
|
||||
request.session.save()
|
||||
return request
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
"""authentik database utilities"""
|
||||
|
||||
import gc
|
||||
from collections.abc import Generator
|
||||
from typing import TypeVar
|
||||
|
||||
from django.db import reset_queries
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models import Model, QuerySet
|
||||
|
||||
ModelT_co = TypeVar("ModelT_co", bound=Model, covariant=True)
|
||||
|
||||
|
||||
def chunked_queryset(queryset: QuerySet, chunk_size: int = 1_000):
|
||||
def chunked_queryset(
|
||||
queryset: QuerySet[ModelT_co], chunk_size: int = 1_000
|
||||
) -> Generator[ModelT_co]:
|
||||
if not queryset.exists():
|
||||
return []
|
||||
return
|
||||
|
||||
def get_chunks(qs: QuerySet):
|
||||
def get_chunks(qs: QuerySet[ModelT_co]) -> Generator[QuerySet[ModelT_co]]:
|
||||
qs = qs.order_by("pk")
|
||||
pks = qs.values_list("pk", flat=True)
|
||||
start_pk = pks[0]
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
type rdict[R] = dict[str, "rdict[R] | R"]
|
||||
|
||||
|
||||
def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
|
||||
def get_path_from_dict[R: Any](
|
||||
root: rdict[R],
|
||||
path: str,
|
||||
sep: str = ".",
|
||||
default: R | None = None,
|
||||
) -> Any | None:
|
||||
"""Recursively walk through `root`, checking each part of `path` separated by `sep`.
|
||||
If at any point a dict does not exist, return default"""
|
||||
walk: Any = root
|
||||
@@ -10,10 +17,10 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
|
||||
walk = walk.get(comp)
|
||||
else:
|
||||
return default
|
||||
return walk
|
||||
return cast(R, walk)
|
||||
|
||||
|
||||
def set_path_in_dict(root: dict, path: str, value: Any, sep="."):
|
||||
def set_path_in_dict[R: Any](root: rdict[R], path: str, value: R, sep: str = ".") -> None:
|
||||
"""Recursively walk through `root`, checking each part of `path` separated by `sep`
|
||||
and setting the last value to `value`"""
|
||||
# Walk each component of the path
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""file utils"""
|
||||
|
||||
from django.db.models import Model
|
||||
from django.http import HttpResponseBadRequest
|
||||
from django.http import HttpResponse, HttpResponseBadRequest
|
||||
from rest_framework.fields import BooleanField, CharField, FileField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
@@ -25,7 +25,7 @@ class FilePathSerializer(PassiveSerializer):
|
||||
url = CharField()
|
||||
|
||||
|
||||
def set_file(request: Request, obj: Model, field_name: str):
|
||||
def set_file(request: Request, obj: Model, field_name: str) -> HttpResponse:
|
||||
"""Upload file"""
|
||||
field = getattr(obj, field_name)
|
||||
file = request.FILES.get("file", None)
|
||||
@@ -45,7 +45,7 @@ def set_file(request: Request, obj: Model, field_name: str):
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
|
||||
def set_file_url(request: Request, obj: Model, field_name: str):
|
||||
def set_file_url(request: Request, obj: Model, field_name: str) -> HttpResponse:
|
||||
"""Set file field to URL"""
|
||||
field = getattr(obj, field_name)
|
||||
url = request.data.get("url", None)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""http helpers"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from requests.models import Response
|
||||
from requests.sessions import PreparedRequest, Session
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik import authentik_full_version
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from requests.sessions import _Timeout
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@@ -19,50 +24,40 @@ def authentik_user_agent() -> str:
|
||||
class TimeoutSession(Session):
|
||||
"""Always set a default HTTP request timeout"""
|
||||
|
||||
def __init__(self, default_timeout=None):
|
||||
def __init__(self, default_timeout: int | None = None) -> None:
|
||||
super().__init__()
|
||||
self.timeout = default_timeout
|
||||
|
||||
def send(
|
||||
self,
|
||||
request,
|
||||
request: PreparedRequest,
|
||||
*,
|
||||
stream=...,
|
||||
verify=...,
|
||||
proxies=...,
|
||||
cert=...,
|
||||
timeout=...,
|
||||
allow_redirects=...,
|
||||
**kwargs,
|
||||
):
|
||||
timeout: "_Timeout | None" = None,
|
||||
**kwargs: Any,
|
||||
) -> Response:
|
||||
if not timeout and self.timeout:
|
||||
timeout = self.timeout
|
||||
return super().send(
|
||||
request,
|
||||
stream=stream,
|
||||
verify=verify,
|
||||
proxies=proxies,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
**kwargs,
|
||||
)
|
||||
return super().send(request, timeout=timeout, **kwargs)
|
||||
|
||||
|
||||
class DebugSession(TimeoutSession):
|
||||
"""requests session which logs http requests and responses"""
|
||||
|
||||
def send(self, req: PreparedRequest, *args, **kwargs):
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
**kwargs: Any,
|
||||
) -> Response:
|
||||
request_id = str(uuid4())
|
||||
LOGGER.debug(
|
||||
"HTTP request sent",
|
||||
uid=request_id,
|
||||
url=req.url,
|
||||
method=req.method,
|
||||
headers=req.headers,
|
||||
body=req.body,
|
||||
url=request.url,
|
||||
method=request.method,
|
||||
headers=request.headers,
|
||||
body=request.body,
|
||||
)
|
||||
resp = super().send(req, *args, **kwargs)
|
||||
resp = super().send(request, **kwargs)
|
||||
LOGGER.debug(
|
||||
"HTTP response received",
|
||||
uid=request_id,
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""authentik lib reflection utilities"""
|
||||
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from tempfile import gettempdir
|
||||
from typing import cast
|
||||
|
||||
from django.apps.config import AppConfig
|
||||
from django.conf import settings
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
@@ -13,9 +16,9 @@ from authentik.lib.config import CONFIG
|
||||
SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST"
|
||||
|
||||
|
||||
def all_subclasses[T: type](cls: T, sort=True) -> list[T] | set[T]:
|
||||
def all_subclasses[T: type](cls: T, sort: bool = True) -> list[T] | set[T]:
|
||||
"""Recursively return all subclassess of cls"""
|
||||
classes = set(cls.__subclasses__()).union(
|
||||
classes: list[T] | set[T] = set(cls.__subclasses__()).union(
|
||||
[s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)]
|
||||
)
|
||||
# Check if we're in debug mode, if not exclude classes which have `__debug_only__`
|
||||
@@ -38,10 +41,10 @@ def path_to_class(path: str = "") -> type:
|
||||
parts = path.split(".")
|
||||
package = ".".join(parts[:-1])
|
||||
_class = getattr(import_module(package), parts[-1])
|
||||
return _class
|
||||
return cast(type, _class)
|
||||
|
||||
|
||||
def get_apps():
|
||||
def get_apps() -> Generator[AppConfig]:
|
||||
"""Get list of all authentik apps"""
|
||||
from django.apps.registry import apps
|
||||
|
||||
@@ -65,11 +68,11 @@ def get_env() -> str:
|
||||
return "custom"
|
||||
|
||||
|
||||
def ConditionalInheritance(path: str):
|
||||
def ConditionalInheritance(path: str) -> type:
|
||||
"""Conditionally inherit from a class, intended for things like authentik.enterprise,
|
||||
without which authentik should still be able to run"""
|
||||
try:
|
||||
cls = import_string(path)
|
||||
return cls
|
||||
return cast(type, cls)
|
||||
except ModuleNotFoundError:
|
||||
return object
|
||||
|
||||
@@ -19,7 +19,7 @@ ALLOWED_KEYS = (
|
||||
)
|
||||
|
||||
|
||||
def timedelta_string_validator(value: str):
|
||||
def timedelta_string_validator(value: str) -> None:
|
||||
"""Validator for Django that checks if value can be parsed with `timedelta_from_string`"""
|
||||
try:
|
||||
timedelta_from_string(value)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""URL-related utils"""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.http import HttpResponse, QueryDict
|
||||
@@ -10,12 +11,12 @@ from structlog.stdlib import get_logger
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def is_url_absolute(url):
|
||||
def is_url_absolute(url: str | bytes | bytearray | None) -> bool:
|
||||
"""Check if domain is absolute to prevent user from being redirect somewhere else"""
|
||||
return bool(urlparse(url).netloc)
|
||||
|
||||
|
||||
def redirect_with_qs(view: str, get_query_set: QueryDict | None = None, **kwargs) -> HttpResponse:
|
||||
def redirect_with_qs(view: str, qs: QueryDict | None = None, **kwargs: Any) -> HttpResponse:
|
||||
"""Wrapper to redirect whilst keeping GET Parameters"""
|
||||
try:
|
||||
target = reverse(view, kwargs=kwargs)
|
||||
@@ -24,14 +25,14 @@ def redirect_with_qs(view: str, get_query_set: QueryDict | None = None, **kwargs
|
||||
return redirect(view)
|
||||
LOGGER.warning("redirect target is not a valid view", view=view)
|
||||
raise
|
||||
if get_query_set:
|
||||
target += "?" + get_query_set.urlencode()
|
||||
if qs:
|
||||
target += "?" + qs.urlencode()
|
||||
return redirect(target)
|
||||
|
||||
|
||||
def reverse_with_qs(view: str, query: QueryDict | None = None, **kwargs) -> str:
|
||||
def reverse_with_qs(view: str, qs: QueryDict | None = None, **kwargs: Any) -> str:
|
||||
"""Reverse a view to it's url but include get params"""
|
||||
url = reverse(view, **kwargs)
|
||||
if query:
|
||||
url += "?" + query.urlencode()
|
||||
if qs:
|
||||
url += "?" + qs.urlencode()
|
||||
return url
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""Serializer validators"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.serializers import Serializer
|
||||
from rest_framework.utils.representation import smart_repr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.utils.functional import _StrPromise
|
||||
|
||||
_IN = TypeVar("_IN") # Instance Type
|
||||
|
||||
|
||||
class RequiredTogetherValidator:
|
||||
"""Serializer-level validator that ensures all fields in `fields` are only
|
||||
@@ -12,13 +19,13 @@ class RequiredTogetherValidator:
|
||||
|
||||
fields: list[str]
|
||||
requires_context = True
|
||||
message = _("The fields {field_names} must be used together.")
|
||||
message: "str | _StrPromise" = _("The fields {field_names} must be used together.")
|
||||
|
||||
def __init__(self, fields: list[str], message: str | None = None) -> None:
|
||||
def __init__(self, fields: list[str], message: "str | _StrPromise | None" = None) -> None:
|
||||
self.fields = fields
|
||||
self.message = message or self.message
|
||||
|
||||
def __call__(self, attrs: dict, serializer: Serializer):
|
||||
def __call__(self, attrs: dict[Any, Any], serializer: Serializer[_IN]) -> None:
|
||||
"""Check that if any of the fields in `self.fields` are set, all of them must be set"""
|
||||
if any(field in attrs for field in self.fields) and not all(
|
||||
field in attrs for field in self.fields
|
||||
@@ -27,5 +34,5 @@ class RequiredTogetherValidator:
|
||||
message = self.message.format(field_names=field_names)
|
||||
raise ValidationError(message, code="required")
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(fields={smart_repr(self.fields)})>"
|
||||
|
||||
@@ -8,8 +8,8 @@ from django.utils.translation import gettext_lazy as _
|
||||
def bad_request_message(
|
||||
request: HttpRequest,
|
||||
message: str,
|
||||
title="Bad Request",
|
||||
template="if/error.html",
|
||||
title: str = "Bad Request",
|
||||
template: str = "if/error.html",
|
||||
) -> TemplateResponse:
|
||||
"""Return generic error page with message, with status code set to 400"""
|
||||
return TemplateResponse(
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""XML Utilities"""
|
||||
|
||||
from lxml.etree import XMLParser, fromstring # nosec
|
||||
from lxml.etree import XMLParser, _Element, fromstring # nosec
|
||||
|
||||
|
||||
def get_lxml_parser():
|
||||
def get_lxml_parser() -> XMLParser:
|
||||
"""Get XML parser"""
|
||||
return XMLParser(resolve_entities=False)
|
||||
|
||||
|
||||
def lxml_from_string(text: str):
|
||||
def lxml_from_string(text: str) -> _Element:
|
||||
"""Wrapper around fromstring"""
|
||||
return fromstring(text, parser=get_lxml_parser()) # nosec
|
||||
|
||||
@@ -5,7 +5,6 @@ from multiprocessing.connection import Connection
|
||||
|
||||
from django.core.cache import cache
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
@@ -121,7 +120,7 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
)
|
||||
return policy_result
|
||||
|
||||
def profiling_wrapper(self):
|
||||
def profiling_wrapper(self) -> PolicyResult:
|
||||
"""Run with profiling enabled"""
|
||||
with (
|
||||
start_span(
|
||||
@@ -135,7 +134,6 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
mode="execute_process",
|
||||
).time(),
|
||||
):
|
||||
span: Span
|
||||
span.set_data("policy", self.binding.policy)
|
||||
span.set_data("request", self.request)
|
||||
return self.execute()
|
||||
|
||||
@@ -458,7 +458,7 @@ class BaseGrantModel(models.Model):
|
||||
return self._scope.split()
|
||||
|
||||
@scope.setter
|
||||
def scope(self, value):
|
||||
def scope(self, value: list[str]) -> None:
|
||||
self._scope = " ".join(value)
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""SCIMProviderGroup API Views"""
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import PartialGroupSerializer
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
|
||||
from authentik.providers.scim.models import SCIMProviderGroup
|
||||
|
||||
|
||||
@@ -16,7 +12,6 @@ class SCIMProviderGroupSerializer(ModelSerializer):
|
||||
group_obj = PartialGroupSerializer(source="group", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
model = SCIMProviderGroup
|
||||
fields = [
|
||||
"id",
|
||||
@@ -29,15 +24,7 @@ class SCIMProviderGroupSerializer(ModelSerializer):
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class SCIMProviderGroupViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
OutgoingSyncConnectionCreateMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
class SCIMProviderGroupViewSet(OutgoingSyncConnectionViewSet):
|
||||
"""SCIMProviderGroup Viewset"""
|
||||
|
||||
queryset = SCIMProviderGroup.objects.all().select_related("group")
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""SCIM Provider API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderViewSet
|
||||
from authentik.lib.utils.reflection import ConditionalInheritance
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
@@ -45,13 +42,16 @@ class SCIMProviderSerializer(
|
||||
extra_kwargs = {}
|
||||
|
||||
|
||||
class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
|
||||
class SCIMProviderViewSet(OutgoingSyncProviderViewSet):
|
||||
"""SCIMProvider Viewset"""
|
||||
|
||||
queryset = SCIMProvider.objects.all()
|
||||
serializer_class = SCIMProviderSerializer
|
||||
filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
|
||||
search_fields = ["name", "url"]
|
||||
ordering = ["name", "url"]
|
||||
filterset_fields = OutgoingSyncProviderViewSet.filterset_fields + [
|
||||
"url",
|
||||
]
|
||||
search_fields = OutgoingSyncProviderViewSet.search_fields + [
|
||||
"url",
|
||||
]
|
||||
sync_task = scim_sync
|
||||
sync_objects_task = scim_sync_objects
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""SCIMProviderUser API Views"""
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import PartialUserSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionViewSet
|
||||
from authentik.providers.scim.models import SCIMProviderUser
|
||||
|
||||
|
||||
@@ -16,7 +12,6 @@ class SCIMProviderUserSerializer(ModelSerializer):
|
||||
user_obj = PartialUserSerializer(source="user", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
model = SCIMProviderUser
|
||||
fields = [
|
||||
"id",
|
||||
@@ -29,15 +24,7 @@ class SCIMProviderUserSerializer(ModelSerializer):
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class SCIMProviderUserViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
OutgoingSyncConnectionCreateMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
class SCIMProviderUserViewSet(OutgoingSyncConnectionViewSet):
|
||||
"""SCIMProviderUser Viewset"""
|
||||
|
||||
queryset = SCIMProviderUser.objects.all().select_related("user")
|
||||
|
||||
@@ -12,7 +12,7 @@ from requests.auth import AuthBase
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
|
||||
from authentik.core.models import Group, PropertyMapping, User, UserTypes
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
@@ -81,7 +81,7 @@ class SCIMCompatibilityMode(models.TextChoices):
|
||||
SALESFORCE = "sfdc", _("Salesforce")
|
||||
|
||||
|
||||
class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
class SCIMProvider(OutgoingSyncProvider):
|
||||
"""SCIM 2.0 provider to create users and groups in external applications"""
|
||||
|
||||
exclude_users_service_account = models.BooleanField(default=False)
|
||||
|
||||
@@ -253,19 +253,19 @@ class IdentificationStageView(ChallengeStageView):
|
||||
if current_stage.enrollment_flow:
|
||||
challenge.initial_data["enroll_url"] = reverse_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
query=get_qs,
|
||||
qs=get_qs,
|
||||
kwargs={"flow_slug": current_stage.enrollment_flow.slug},
|
||||
)
|
||||
if current_stage.recovery_flow:
|
||||
challenge.initial_data["recovery_url"] = reverse_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
query=get_qs,
|
||||
qs=get_qs,
|
||||
kwargs={"flow_slug": current_stage.recovery_flow.slug},
|
||||
)
|
||||
if current_stage.passwordless_flow:
|
||||
challenge.initial_data["passwordless_url"] = reverse_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
query=get_qs,
|
||||
qs=get_qs,
|
||||
kwargs={"flow_slug": current_stage.passwordless_flow.slug},
|
||||
)
|
||||
|
||||
|
||||
@@ -104,7 +104,9 @@ dev = [
|
||||
"requests-mock==1.12.1",
|
||||
"ruff==0.11.9",
|
||||
"selenium==4.32.0",
|
||||
"types-cachetools==6.2.0.20251022",
|
||||
"types-channels==4.3.0.20250822",
|
||||
"types-docker==7.1.0.20251009",
|
||||
"types-ldap3==2.9.13.20250622",
|
||||
]
|
||||
|
||||
@@ -208,7 +210,7 @@ plugins = ["mypy_django_plugin.main", "mypy_drf_plugin.main", "pydantic.mypy"]
|
||||
exclude = ['^gen-py-api/']
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["django_tenants.*", "dramatiq.*", "pglock.*"]
|
||||
module = ["django_tenants.*", "dramatiq.*", "pglock.*", "debugpy.*"]
|
||||
follow_untyped_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
@@ -226,7 +228,9 @@ module = [
|
||||
"authentik.enterprise.*",
|
||||
"authentik.events.*",
|
||||
"authentik.flows.*",
|
||||
"authentik.lib.*",
|
||||
"authentik.lib.config",
|
||||
"authentik.lib.models",
|
||||
"authentik.lib.tests.test_config",
|
||||
"authentik.outposts.*",
|
||||
"authentik.policies.*",
|
||||
"authentik.policies.dummy.*",
|
||||
|
||||
39
uv.lock
generated
39
uv.lock
generated
@@ -272,7 +272,9 @@ dev = [
|
||||
{ name = "requests-mock" },
|
||||
{ name = "ruff" },
|
||||
{ name = "selenium" },
|
||||
{ name = "types-cachetools" },
|
||||
{ name = "types-channels" },
|
||||
{ name = "types-docker" },
|
||||
{ name = "types-ldap3" },
|
||||
]
|
||||
|
||||
@@ -377,7 +379,9 @@ dev = [
|
||||
{ name = "requests-mock", specifier = "==1.12.1" },
|
||||
{ name = "ruff", specifier = "==0.11.9" },
|
||||
{ name = "selenium", specifier = "==4.32.0" },
|
||||
{ name = "types-cachetools", specifier = "==6.2.0.20251022" },
|
||||
{ name = "types-channels", specifier = "==4.3.0.20250822" },
|
||||
{ name = "types-docker", specifier = "==7.1.0.20251009" },
|
||||
{ name = "types-ldap3", specifier = "==2.9.13.20250622" },
|
||||
]
|
||||
|
||||
@@ -3389,6 +3393,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/bb/d43e5c75054e53efce310e79d63df0ac3f25e34c926be5dffb7d283fb2a8/typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1", size = 17605, upload-time = "2021-12-10T21:09:37.844Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-cachetools"
|
||||
version = "6.2.0.20251022"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3b/a8/f9bcc7f1be63af43ef0170a773e2d88817bcc7c9d8769f2228c802826efe/types_cachetools-6.2.0.20251022.tar.gz", hash = "sha256:f1d3c736f0f741e89ec10f0e1b0138625023e21eb33603a930c149e0318c0cef", size = 9608, upload-time = "2025-10-22T03:03:58.16Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/98/2d/8d821ed80f6c2c5b427f650bf4dc25b80676ed63d03388e4b637d2557107/types_cachetools-6.2.0.20251022-py3-none-any.whl", hash = "sha256:698eb17b8f16b661b90624708b6915f33dbac2d185db499ed57e4997e7962cad", size = 9341, upload-time = "2025-10-22T03:03:57.036Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-channels"
|
||||
version = "4.3.0.20250822"
|
||||
@@ -3402,6 +3415,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/52/4e3094e43d460feacb9051ec4c3498f8272f69d92b772647211478b25079/types_channels-4.3.0.20250822-py3-none-any.whl", hash = "sha256:d3fc0a1467c8cc901686826408c8a673822e07aa79cbe1a6d21946e7e55d9ddf", size = 21125, upload-time = "2025-08-22T03:04:25.539Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-docker"
|
||||
version = "7.1.0.20251009"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "types-paramiko" },
|
||||
{ name = "types-requests" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/93/9d/c9427adb04df55f3821b042612f8f7555c7060d6a3b589549a10b7a15c3a/types_docker-7.1.0.20251009.tar.gz", hash = "sha256:37af2a9ed5c3d76308ee9b9958cf1506fe9bcfbfed9c0a20bd9856dbca90424e", size = 31647, upload-time = "2025-10-09T02:54:40.976Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/71/bb/da793021a87528e2ca717f117d0d94149c429425a33dea036675932a0170/types_docker-7.1.0.20251009-py3-none-any.whl", hash = "sha256:e0ed83c70b824d0efffca6e61662e2722109207515579782fa27c505ea06fb7d", size = 46417, upload-time = "2025-10-09T02:54:40.035Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-ldap3"
|
||||
version = "2.9.13.20250622"
|
||||
@@ -3414,6 +3441,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/fd/0339a618985d19d9b0630f78822d1becb0661be6abe8adbadd9569b875e1/types_ldap3-2.9.13.20250622-py3-none-any.whl", hash = "sha256:c18d0320327fa0017eb3d95acdf38921542d80939255e4ba130ca2d13ca3375f", size = 56498, upload-time = "2025-06-22T03:19:15.495Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-paramiko"
|
||||
version = "4.0.0.20250822"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b7/b8/c6ff3b10c2f7b9897650af746f0dc6c5cddf054db857bc79d621f53c7d22/types_paramiko-4.0.0.20250822.tar.gz", hash = "sha256:1b56b0cbd3eec3d2fd123c9eb2704e612b777e15a17705a804279ea6525e0c53", size = 28730, upload-time = "2025-08-22T03:03:43.262Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/79/a1/b3774ed924a66ee2c041224d89c36f0c21f4f6cf75036d6ee7698bf8a4b9/types_paramiko-4.0.0.20250822-py3-none-any.whl", hash = "sha256:55bdb14db75ca89039725ec64ae3fa26b8d57b6991cfb476212fa8f83a59753c", size = 38833, upload-time = "2025-08-22T03:03:42.072Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-pyasn1"
|
||||
version = "0.6.0.20250914"
|
||||
|
||||
Reference in New Issue
Block a user