mirror of
https://github.com/goauthentik/authentik
synced 2026-05-05 22:52:42 +02:00
Compare commits
2 Commits
metadata-f
...
mokeytype_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7325bfde36 | ||
|
|
6383d0bfc0 |
@@ -77,7 +77,7 @@ class GroupSerializer(ModelSerializer):
|
||||
return None
|
||||
return GroupMemberSerializer(instance.users, many=True).data
|
||||
|
||||
def validate_parent(self, parent: Group | None):
|
||||
def validate_parent(self, parent: Group | None) -> None:
|
||||
"""Validate group parent (if set), ensuring the parent isn't itself"""
|
||||
if not self.instance or not parent:
|
||||
return parent
|
||||
@@ -85,7 +85,7 @@ class GroupSerializer(ModelSerializer):
|
||||
raise ValidationError(_("Cannot set group as parent of itself."))
|
||||
return parent
|
||||
|
||||
def validate_is_superuser(self, superuser: bool):
|
||||
def validate_is_superuser(self, superuser: bool) -> bool:
|
||||
"""Ensure that the user creating this group has permissions to set the superuser flag"""
|
||||
request: Request = self.context.get("request", None)
|
||||
if not request:
|
||||
@@ -210,7 +210,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
]
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
def list(self, request: Request, *args, **kwargs) -> Response:
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
@extend_schema(
|
||||
@@ -218,7 +218,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
]
|
||||
)
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
def retrieve(self, request: Request, *args, **kwargs) -> Response:
|
||||
return super().retrieve(request, *args, **kwargs)
|
||||
|
||||
@permission_required("authentik_core.add_user_to_group")
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.db.models.query import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters.filters import BooleanFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from model_utils.managers import InheritanceQuerySet
|
||||
from rest_framework import mixins
|
||||
from rest_framework.fields import ReadOnlyField, SerializerMethodField
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
@@ -99,5 +100,5 @@ class ProviderViewSet(
|
||||
"application__name",
|
||||
]
|
||||
|
||||
def get_queryset(self): # pragma: no cover
|
||||
def get_queryset(self) -> InheritanceQuerySet: # pragma: no cover
|
||||
return Provider.objects.select_subclasses()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from model_utils.managers import InheritanceQuerySet
|
||||
from rest_framework import mixins
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
@@ -88,7 +89,7 @@ class SourceViewSet(
|
||||
search_fields = ["slug", "name"]
|
||||
filterset_fields = ["slug", "name", "managed", "pbm_uuid"]
|
||||
|
||||
def get_queryset(self): # pragma: no cover
|
||||
def get_queryset(self) -> InheritanceQuerySet: # pragma: no cover
|
||||
return Source.objects.select_subclasses()
|
||||
|
||||
@permission_required("authentik_core.change_source")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.db.models.query import QuerySet
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer
|
||||
from guardian.shortcuts import assign_perm, get_anonymous_user
|
||||
@@ -41,7 +42,7 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
||||
self.fields["key"] = CharField(required=False)
|
||||
|
||||
def validate_user(self, user: User):
|
||||
def validate_user(self, user: User) -> User:
|
||||
"""Ensure user of token cannot be changed"""
|
||||
if self.instance and self.instance.user_id:
|
||||
if user.pk != self.instance.user_id:
|
||||
@@ -138,13 +139,13 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
|
||||
owner_field = "user"
|
||||
rbac_allow_create_without_perm = True
|
||||
|
||||
def get_queryset(self):
|
||||
def get_queryset(self) -> QuerySet:
|
||||
user = self.request.user if self.request else get_anonymous_user()
|
||||
if user.is_superuser:
|
||||
return super().get_queryset()
|
||||
return super().get_queryset().filter(user=user.pk)
|
||||
|
||||
def perform_create(self, serializer: TokenSerializer):
|
||||
def perform_create(self, serializer: TokenSerializer) -> Token:
|
||||
if not self.request.user.is_superuser:
|
||||
instance = serializer.save(
|
||||
user=self.request.user,
|
||||
|
||||
@@ -21,6 +21,7 @@ from django_filters.filters import (
|
||||
UUIDFilter,
|
||||
)
|
||||
from django_filters.filterset import FilterSet
|
||||
from djangoql.schema import BoolField, StrField
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import (
|
||||
OpenApiParameter,
|
||||
@@ -72,8 +73,10 @@ from authentik.core.models import (
|
||||
Token,
|
||||
TokenIntents,
|
||||
User,
|
||||
UserQuerySet,
|
||||
UserTypes,
|
||||
)
|
||||
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import FlowToken
|
||||
@@ -349,7 +352,7 @@ class UsersFilter(FilterSet):
|
||||
queryset=Group.objects.all().order_by("name"),
|
||||
)
|
||||
|
||||
def filter_is_superuser(self, queryset, name, value):
|
||||
def filter_is_superuser(self, queryset: UserQuerySet, name: str, value: bool) -> UserQuerySet:
|
||||
if value:
|
||||
return queryset.filter(ak_groups__is_superuser=True).distinct()
|
||||
return queryset.exclude(ak_groups__is_superuser=True).distinct()
|
||||
@@ -395,7 +398,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
filterset_class = UsersFilter
|
||||
search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
|
||||
|
||||
def get_ql_fields(self):
|
||||
def get_ql_fields(self) -> list[StrField | BoolField | ChoiceSearchField | JSONSearchField]:
|
||||
from djangoql.schema import BoolField, StrField
|
||||
|
||||
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
|
||||
@@ -410,7 +413,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
JSONSearchField(User, "attributes", suggest_nested=False),
|
||||
]
|
||||
|
||||
def get_queryset(self):
|
||||
def get_queryset(self) -> UserQuerySet:
|
||||
base_qs = User.objects.all().exclude_anonymous()
|
||||
if self.serializer_class(context={"request": self.request})._should_include_groups:
|
||||
base_qs = base_qs.prefetch_related("ak_groups")
|
||||
@@ -421,10 +424,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
OpenApiParameter("include_groups", bool, default=True),
|
||||
]
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
def list(self, request: Request, *args, **kwargs) -> Response:
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
def _create_recovery_link(self, for_email=False) -> tuple[str, Token]:
|
||||
def _create_recovery_link(self, for_email: bool = False) -> tuple[str, Token]:
|
||||
"""Create a recovery link (when the current brand has a recovery flow set),
|
||||
that can either be shown to an admin or sent to the user directly"""
|
||||
brand: Brand = self.request._request.brand
|
||||
|
||||
@@ -42,7 +42,7 @@ class JSONExtension(OpenApiSerializerFieldExtension):
|
||||
|
||||
target_class = "authentik.core.api.utils.JSONDictField"
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
def map_serializer_field(self, auto_schema, direction: str) -> dict[str, str]:
|
||||
return build_basic_type(OpenApiTypes.OBJECT)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class ModelSerializer(BaseModelSerializer):
|
||||
serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy()
|
||||
serializer_field_mapping[models.JSONField] = JSONDictField
|
||||
|
||||
def create(self, validated_data):
|
||||
def create(self, validated_data: dict[str, Any]):
|
||||
instance = super().create(validated_data)
|
||||
|
||||
request = self.context.get("request")
|
||||
@@ -61,7 +61,7 @@ class ModelSerializer(BaseModelSerializer):
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance: Model, validated_data):
|
||||
def update(self, instance: Model, validated_data: dict[str, Any]):
|
||||
raise_errors_on_nested_writes("update", self, validated_data)
|
||||
info = model_meta.get_field_info(instance)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
@@ -14,6 +15,8 @@ from django.utils.translation import override
|
||||
from sentry_sdk.api import set_tag
|
||||
from structlog.contextvars import STRUCTLOG_KEY_PREFIX
|
||||
|
||||
from authentik.core.models import User
|
||||
|
||||
SESSION_KEY_IMPERSONATE_USER = "authentik/impersonate/user"
|
||||
SESSION_KEY_IMPERSONATE_ORIGINAL_USER = "authentik/impersonate/original_user"
|
||||
RESPONSE_HEADER_ID = "X-authentik-id"
|
||||
@@ -25,7 +28,7 @@ CTX_HOST = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "host", default=None)
|
||||
CTX_AUTH_VIA = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None)
|
||||
|
||||
|
||||
def get_user(request):
|
||||
def get_user(request: WSGIRequest) -> AnonymousUser | User:
|
||||
if not hasattr(request, "_cached_user"):
|
||||
user = None
|
||||
if (authenticated_session := request.session.get("authenticatedsession", None)) is not None:
|
||||
@@ -46,7 +49,7 @@ async def aget_user(request):
|
||||
|
||||
|
||||
class AuthenticationMiddleware(MiddlewareMixin):
|
||||
def process_request(self, request):
|
||||
def process_request(self, request: WSGIRequest):
|
||||
if not hasattr(request, "session"):
|
||||
raise ImproperlyConfigured(
|
||||
"The Django authentication middleware requires session "
|
||||
|
||||
@@ -11,6 +11,7 @@ from django.contrib.auth.hashers import check_password
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from django.contrib.auth.models import UserManager as DjangoUserManager
|
||||
from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.db import models
|
||||
from django.db.models import Q, QuerySet, options
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
@@ -22,6 +23,7 @@ from django_cte import CTE, with_cte
|
||||
from guardian.conf import settings
|
||||
from guardian.mixins import GuardianUserMixin
|
||||
from model_utils.managers import InheritanceManager
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@@ -137,7 +139,7 @@ class AttributesMixin(models.Model):
|
||||
|
||||
|
||||
class GroupQuerySet(QuerySet):
|
||||
def with_children_recursive(self):
|
||||
def with_children_recursive(self) -> "GroupQuerySet":
|
||||
"""Recursively get all groups that have the current queryset as parents
|
||||
or are indirectly related."""
|
||||
|
||||
@@ -210,7 +212,7 @@ class Group(SerializerModel, AttributesMixin):
|
||||
("disable_group_superuser", _("Disable superuser status")),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"Group {self.name}"
|
||||
|
||||
@property
|
||||
@@ -241,7 +243,7 @@ class Group(SerializerModel, AttributesMixin):
|
||||
class UserQuerySet(models.QuerySet):
|
||||
"""User queryset"""
|
||||
|
||||
def exclude_anonymous(self):
|
||||
def exclude_anonymous(self) -> "UserQuerySet":
|
||||
"""Exclude anonymous user"""
|
||||
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
|
||||
|
||||
@@ -249,7 +251,7 @@ class UserQuerySet(models.QuerySet):
|
||||
class UserManager(DjangoUserManager):
|
||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||
|
||||
def get_queryset(self):
|
||||
def get_queryset(self) -> UserQuerySet:
|
||||
"""Create special user queryset"""
|
||||
return UserQuerySet(self.model, using=self._db)
|
||||
|
||||
@@ -295,7 +297,7 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
|
||||
models.Index(fields=["type"]),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.username
|
||||
|
||||
@staticmethod
|
||||
@@ -360,7 +362,13 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
|
||||
"""superuser == staff user"""
|
||||
return self.is_superuser # type: ignore
|
||||
|
||||
def set_password(self, raw_password, signal=True, sender=None, request=None):
|
||||
def set_password(
|
||||
self,
|
||||
raw_password: str,
|
||||
signal: bool = True,
|
||||
sender: None = None,
|
||||
request: WSGIRequest | Request | None = None,
|
||||
) -> None:
|
||||
if self.pk and signal:
|
||||
from authentik.core.signals import password_changed
|
||||
|
||||
@@ -479,7 +487,7 @@ class Provider(SerializerModel):
|
||||
"""Get serializer for this model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
|
||||
@@ -611,7 +619,7 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
)
|
||||
return getattr(providers.first(), provider_type._meta.model_name)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
class Meta:
|
||||
@@ -631,7 +639,7 @@ class ApplicationEntitlement(AttributesMixin, SerializerModel, PolicyBindingMode
|
||||
verbose_name_plural = _("Application Entitlements")
|
||||
unique_together = (("app", "name"),)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"Application Entitlement {self.name} for app {self.app_id}"
|
||||
|
||||
@property
|
||||
@@ -640,7 +648,7 @@ class ApplicationEntitlement(AttributesMixin, SerializerModel, PolicyBindingMode
|
||||
|
||||
return ApplicationEntitlementSerializer
|
||||
|
||||
def supported_policy_binding_targets(self):
|
||||
def supported_policy_binding_targets(self) -> list[str]:
|
||||
return ["group", "user"]
|
||||
|
||||
|
||||
@@ -812,7 +820,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||
return {}
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
class Meta:
|
||||
@@ -895,7 +903,7 @@ class ExpiringModel(models.Model):
|
||||
models.Index(fields=["expiring", "expires"]),
|
||||
]
|
||||
|
||||
def expire_action(self, *args, **kwargs):
|
||||
def expire_action(self, *args, **kwargs) -> tuple[int, dict[str, int]]:
|
||||
"""Handler which is called when this object is expired. By
|
||||
default the object is deleted. This is less efficient compared
|
||||
to bulk deleting objects, but classes like Token() need to change
|
||||
@@ -958,7 +966,7 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
|
||||
("set_token_key", _("Set a token's key")),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
description = f"{self.identifier}"
|
||||
if self.expiring:
|
||||
description += f" (expires={self.expires})"
|
||||
@@ -1023,7 +1031,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
||||
except Exception as exc:
|
||||
raise PropertyMappingExpressionException(exc, self) from exc
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"Property Mapping {self.name}"
|
||||
|
||||
class Meta:
|
||||
@@ -1051,7 +1059,7 @@ class Session(ExpiringModel, AbstractBaseSession):
|
||||
]
|
||||
default_permissions = []
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.session_key
|
||||
|
||||
class Keys(StrEnum):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""authentik sessions engine"""
|
||||
|
||||
import pickle # nosec
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
|
||||
from django.contrib.sessions.backends.db import SessionStore as SessionBase
|
||||
@@ -9,13 +10,19 @@ from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Session
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class SessionStore(SessionBase):
|
||||
def __init__(self, session_key=None, last_ip=None, last_user_agent=""):
|
||||
def __init__(
|
||||
self,
|
||||
session_key: str | None = None,
|
||||
last_ip: str | None = None,
|
||||
last_user_agent: str = "",
|
||||
):
|
||||
super().__init__(session_key)
|
||||
self._create_kwargs = {
|
||||
"last_ip": last_ip or ClientIPMiddleware.default_ip,
|
||||
@@ -23,16 +30,16 @@ class SessionStore(SessionBase):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model_class(cls):
|
||||
def get_model_class(cls) -> type[Session]:
|
||||
from authentik.core.models import Session
|
||||
|
||||
return Session
|
||||
|
||||
@cached_property
|
||||
def model_fields(self):
|
||||
def model_fields(self) -> list[str]:
|
||||
return [k.value for k in self.model.Keys]
|
||||
|
||||
def _get_session_from_db(self):
|
||||
def _get_session_from_db(self) -> Session:
|
||||
try:
|
||||
return (
|
||||
self.model.objects.select_related(
|
||||
@@ -74,10 +81,10 @@ class SessionStore(SessionBase):
|
||||
LOGGER.warning(str(exc))
|
||||
self._session_key = None
|
||||
|
||||
def encode(self, session_dict):
|
||||
def encode(self, session_dict: dict[str, Any]) -> bytes:
|
||||
return pickle.dumps(session_dict, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def decode(self, session_data):
|
||||
def decode(self, session_data: bytes) -> dict[str, Any]:
|
||||
try:
|
||||
return pickle.loads(session_data) # nosec
|
||||
except pickle.PickleError:
|
||||
@@ -86,7 +93,7 @@ class SessionStore(SessionBase):
|
||||
pass
|
||||
return {}
|
||||
|
||||
def load(self):
|
||||
def load(self) -> dict[str, Any]:
|
||||
s = self._get_session_from_db()
|
||||
if s:
|
||||
return {
|
||||
@@ -108,7 +115,7 @@ class SessionStore(SessionBase):
|
||||
else:
|
||||
return {}
|
||||
|
||||
def create_model_instance(self, data):
|
||||
def create_model_instance(self, data: dict[str, Any]) -> Session:
|
||||
args = {
|
||||
"session_key": self._get_or_create_session_key(),
|
||||
"expires": self.get_expiry_date(),
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from django.http.response import (
|
||||
HttpResponseBadRequest,
|
||||
HttpResponseForbidden,
|
||||
HttpResponseNotAllowed,
|
||||
HttpResponseNotFound,
|
||||
HttpResponseServerError,
|
||||
)
|
||||
@@ -61,6 +62,6 @@ class ServerErrorView(TemplateView):
|
||||
response_class = ServerErrorTemplateResponse
|
||||
template_name = "if/error.html"
|
||||
|
||||
def dispatch(self, *args, **kwargs): # pragma: no cover
|
||||
def dispatch(self, *args, **kwargs) -> HttpResponseNotAllowed: # pragma: no cover
|
||||
"""Little wrapper so django accepts this function"""
|
||||
return super().dispatch(*args, **kwargs)
|
||||
|
||||
20
monkeytype_config.py
Normal file
20
monkeytype_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Standard Library
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
|
||||
# 3rd-party
|
||||
from monkeytype.config import DefaultConfig
|
||||
|
||||
|
||||
class MonkeyConfig(DefaultConfig):
|
||||
@contextmanager
|
||||
def cli_context(self, command: str) -> Iterator[None]:
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
|
||||
import django
|
||||
|
||||
django.setup()
|
||||
yield
|
||||
|
||||
|
||||
CONFIG = MonkeyConfig()
|
||||
Reference in New Issue
Block a user