Compare commits

...

2 Commits

Author SHA1 Message Date
Marcelo Elizeche Landó
7325bfde36 add monkeytype_config.py 2025-07-23 13:10:03 -03:00
Marcelo Elizeche Landó
6383d0bfc0 Add types using typemonkey 2025-07-18 02:28:56 -03:00
11 changed files with 88 additions and 43 deletions

View File

@@ -77,7 +77,7 @@ class GroupSerializer(ModelSerializer):
return None
return GroupMemberSerializer(instance.users, many=True).data
def validate_parent(self, parent: Group | None):
def validate_parent(self, parent: Group | None) -> None:
"""Validate group parent (if set), ensuring the parent isn't itself"""
if not self.instance or not parent:
return parent
@@ -85,7 +85,7 @@ class GroupSerializer(ModelSerializer):
raise ValidationError(_("Cannot set group as parent of itself."))
return parent
def validate_is_superuser(self, superuser: bool):
def validate_is_superuser(self, superuser: bool) -> bool:
"""Ensure that the user creating this group has permissions to set the superuser flag"""
request: Request = self.context.get("request", None)
if not request:
@@ -210,7 +210,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
OpenApiParameter("include_users", bool, default=True),
]
)
def list(self, request, *args, **kwargs):
def list(self, request: Request, *args, **kwargs) -> Response:
return super().list(request, *args, **kwargs)
@extend_schema(
@@ -218,7 +218,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
OpenApiParameter("include_users", bool, default=True),
]
)
def retrieve(self, request, *args, **kwargs):
def retrieve(self, request: Request, *args, **kwargs) -> Response:
return super().retrieve(request, *args, **kwargs)
@permission_required("authentik_core.add_user_to_group")

View File

@@ -5,6 +5,7 @@ from django.db.models.query import Q
from django.utils.translation import gettext_lazy as _
from django_filters.filters import BooleanFilter
from django_filters.filterset import FilterSet
from model_utils.managers import InheritanceQuerySet
from rest_framework import mixins
from rest_framework.fields import ReadOnlyField, SerializerMethodField
from rest_framework.viewsets import GenericViewSet
@@ -99,5 +100,5 @@ class ProviderViewSet(
"application__name",
]
def get_queryset(self): # pragma: no cover
def get_queryset(self) -> InheritanceQuerySet: # pragma: no cover
return Provider.objects.select_subclasses()

View File

@@ -3,6 +3,7 @@
from collections.abc import Iterable
from drf_spectacular.utils import OpenApiResponse, extend_schema
from model_utils.managers import InheritanceQuerySet
from rest_framework import mixins
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
@@ -88,7 +89,7 @@ class SourceViewSet(
search_fields = ["slug", "name"]
filterset_fields = ["slug", "name", "managed", "pbm_uuid"]
def get_queryset(self): # pragma: no cover
def get_queryset(self) -> InheritanceQuerySet: # pragma: no cover
return Source.objects.select_subclasses()
@permission_required("authentik_core.change_source")

View File

@@ -2,6 +2,7 @@
from typing import Any
from django.db.models.query import QuerySet
from django.utils.timezone import now
from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer
from guardian.shortcuts import assign_perm, get_anonymous_user
@@ -41,7 +42,7 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["key"] = CharField(required=False)
def validate_user(self, user: User):
def validate_user(self, user: User) -> User:
"""Ensure user of token cannot be changed"""
if self.instance and self.instance.user_id:
if user.pk != self.instance.user_id:
@@ -138,13 +139,13 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
owner_field = "user"
rbac_allow_create_without_perm = True
def get_queryset(self):
def get_queryset(self) -> QuerySet:
user = self.request.user if self.request else get_anonymous_user()
if user.is_superuser:
return super().get_queryset()
return super().get_queryset().filter(user=user.pk)
def perform_create(self, serializer: TokenSerializer):
def perform_create(self, serializer: TokenSerializer) -> Token:
if not self.request.user.is_superuser:
instance = serializer.save(
user=self.request.user,

View File

@@ -21,6 +21,7 @@ from django_filters.filters import (
UUIDFilter,
)
from django_filters.filterset import FilterSet
from djangoql.schema import BoolField, StrField
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
OpenApiParameter,
@@ -72,8 +73,10 @@ from authentik.core.models import (
Token,
TokenIntents,
User,
UserQuerySet,
UserTypes,
)
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
from authentik.events.models import Event, EventAction
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import FlowToken
@@ -349,7 +352,7 @@ class UsersFilter(FilterSet):
queryset=Group.objects.all().order_by("name"),
)
def filter_is_superuser(self, queryset, name, value):
def filter_is_superuser(self, queryset: UserQuerySet, name: str, value: bool) -> UserQuerySet:
if value:
return queryset.filter(ak_groups__is_superuser=True).distinct()
return queryset.exclude(ak_groups__is_superuser=True).distinct()
@@ -395,7 +398,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
filterset_class = UsersFilter
search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
def get_ql_fields(self):
def get_ql_fields(self) -> list[StrField | BoolField | ChoiceSearchField | JSONSearchField]:
from djangoql.schema import BoolField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
@@ -410,7 +413,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
JSONSearchField(User, "attributes", suggest_nested=False),
]
def get_queryset(self):
def get_queryset(self) -> UserQuerySet:
base_qs = User.objects.all().exclude_anonymous()
if self.serializer_class(context={"request": self.request})._should_include_groups:
base_qs = base_qs.prefetch_related("ak_groups")
@@ -421,10 +424,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
OpenApiParameter("include_groups", bool, default=True),
]
)
def list(self, request, *args, **kwargs):
def list(self, request: Request, *args, **kwargs) -> Response:
return super().list(request, *args, **kwargs)
def _create_recovery_link(self, for_email=False) -> tuple[str, Token]:
def _create_recovery_link(self, for_email: bool = False) -> tuple[str, Token]:
"""Create a recovery link (when the current brand has a recovery flow set),
that can either be shown to an admin or sent to the user directly"""
brand: Brand = self.request._request.brand

View File

@@ -42,7 +42,7 @@ class JSONExtension(OpenApiSerializerFieldExtension):
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
def map_serializer_field(self, auto_schema, direction: str) -> dict[str, str]:
return build_basic_type(OpenApiTypes.OBJECT)
@@ -52,7 +52,7 @@ class ModelSerializer(BaseModelSerializer):
serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy()
serializer_field_mapping[models.JSONField] = JSONDictField
def create(self, validated_data):
def create(self, validated_data: dict[str, Any]):
instance = super().create(validated_data)
request = self.context.get("request")
@@ -61,7 +61,7 @@ class ModelSerializer(BaseModelSerializer):
return instance
def update(self, instance: Model, validated_data):
def update(self, instance: Model, validated_data: dict[str, Any]):
raise_errors_on_nested_writes("update", self, validated_data)
info = model_meta.get_field_info(instance)

View File

@@ -7,6 +7,7 @@ from uuid import uuid4
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIRequest
from django.http import HttpRequest, HttpResponse
from django.utils.deprecation import MiddlewareMixin
from django.utils.functional import SimpleLazyObject
@@ -14,6 +15,8 @@ from django.utils.translation import override
from sentry_sdk.api import set_tag
from structlog.contextvars import STRUCTLOG_KEY_PREFIX
from authentik.core.models import User
SESSION_KEY_IMPERSONATE_USER = "authentik/impersonate/user"
SESSION_KEY_IMPERSONATE_ORIGINAL_USER = "authentik/impersonate/original_user"
RESPONSE_HEADER_ID = "X-authentik-id"
@@ -25,7 +28,7 @@ CTX_HOST = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "host", default=None)
CTX_AUTH_VIA = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None)
def get_user(request):
def get_user(request: WSGIRequest) -> AnonymousUser | User:
if not hasattr(request, "_cached_user"):
user = None
if (authenticated_session := request.session.get("authenticatedsession", None)) is not None:
@@ -46,7 +49,7 @@ async def aget_user(request):
class AuthenticationMiddleware(MiddlewareMixin):
def process_request(self, request):
def process_request(self, request: WSGIRequest):
if not hasattr(request, "session"):
raise ImproperlyConfigured(
"The Django authentication middleware requires session "

View File

@@ -11,6 +11,7 @@ from django.contrib.auth.hashers import check_password
from django.contrib.auth.models import AbstractUser
from django.contrib.auth.models import UserManager as DjangoUserManager
from django.contrib.sessions.base_session import AbstractBaseSession
from django.core.handlers.wsgi import WSGIRequest
from django.db import models
from django.db.models import Q, QuerySet, options
from django.db.models.constants import LOOKUP_SEP
@@ -22,6 +23,7 @@ from django_cte import CTE, with_cte
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager
from rest_framework.request import Request
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
@@ -137,7 +139,7 @@ class AttributesMixin(models.Model):
class GroupQuerySet(QuerySet):
def with_children_recursive(self):
def with_children_recursive(self) -> "GroupQuerySet":
"""Recursively get all groups that have the current queryset as parents
or are indirectly related."""
@@ -210,7 +212,7 @@ class Group(SerializerModel, AttributesMixin):
("disable_group_superuser", _("Disable superuser status")),
]
def __str__(self):
def __str__(self) -> str:
return f"Group {self.name}"
@property
@@ -241,7 +243,7 @@ class Group(SerializerModel, AttributesMixin):
class UserQuerySet(models.QuerySet):
"""User queryset"""
def exclude_anonymous(self):
def exclude_anonymous(self) -> "UserQuerySet":
"""Exclude anonymous user"""
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
@@ -249,7 +251,7 @@ class UserQuerySet(models.QuerySet):
class UserManager(DjangoUserManager):
"""User manager that doesn't assign is_superuser and is_staff"""
def get_queryset(self):
def get_queryset(self) -> UserQuerySet:
"""Create special user queryset"""
return UserQuerySet(self.model, using=self._db)
@@ -295,7 +297,7 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
models.Index(fields=["type"]),
]
def __str__(self):
def __str__(self) -> str:
return self.username
@staticmethod
@@ -360,7 +362,13 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
"""superuser == staff user"""
return self.is_superuser # type: ignore
def set_password(self, raw_password, signal=True, sender=None, request=None):
def set_password(
self,
raw_password: str,
signal: bool = True,
sender: None = None,
request: WSGIRequest | Request | None = None,
) -> None:
if self.pk and signal:
from authentik.core.signals import password_changed
@@ -479,7 +487,7 @@ class Provider(SerializerModel):
"""Get serializer for this model"""
raise NotImplementedError
def __str__(self):
def __str__(self) -> str:
return str(self.name)
@@ -611,7 +619,7 @@ class Application(SerializerModel, PolicyBindingModel):
)
return getattr(providers.first(), provider_type._meta.model_name)
def __str__(self):
def __str__(self) -> str:
return str(self.name)
class Meta:
@@ -631,7 +639,7 @@ class ApplicationEntitlement(AttributesMixin, SerializerModel, PolicyBindingMode
verbose_name_plural = _("Application Entitlements")
unique_together = (("app", "name"),)
def __str__(self):
def __str__(self) -> str:
return f"Application Entitlement {self.name} for app {self.app_id}"
@property
@@ -640,7 +648,7 @@ class ApplicationEntitlement(AttributesMixin, SerializerModel, PolicyBindingMode
return ApplicationEntitlementSerializer
def supported_policy_binding_targets(self):
def supported_policy_binding_targets(self) -> list[str]:
return ["group", "user"]
@@ -812,7 +820,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
return {}
raise NotImplementedError
def __str__(self):
def __str__(self) -> str:
return str(self.name)
class Meta:
@@ -895,7 +903,7 @@ class ExpiringModel(models.Model):
models.Index(fields=["expiring", "expires"]),
]
def expire_action(self, *args, **kwargs):
def expire_action(self, *args, **kwargs) -> tuple[int, dict[str, int]]:
"""Handler which is called when this object is expired. By
default the object is deleted. This is less efficient compared
to bulk deleting objects, but classes like Token() need to change
@@ -958,7 +966,7 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
("set_token_key", _("Set a token's key")),
]
def __str__(self):
def __str__(self) -> str:
description = f"{self.identifier}"
if self.expiring:
description += f" (expires={self.expires})"
@@ -1023,7 +1031,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
except Exception as exc:
raise PropertyMappingExpressionException(exc, self) from exc
def __str__(self):
def __str__(self) -> str:
return f"Property Mapping {self.name}"
class Meta:
@@ -1051,7 +1059,7 @@ class Session(ExpiringModel, AbstractBaseSession):
]
default_permissions = []
def __str__(self):
def __str__(self) -> str:
return self.session_key
class Keys(StrEnum):

View File

@@ -1,6 +1,7 @@
"""authentik sessions engine"""
import pickle # nosec
from typing import Any
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
from django.contrib.sessions.backends.db import SessionStore as SessionBase
@@ -9,13 +10,19 @@ from django.utils import timezone
from django.utils.functional import cached_property
from structlog.stdlib import get_logger
from authentik.core.models import Session
from authentik.root.middleware import ClientIPMiddleware
LOGGER = get_logger()
class SessionStore(SessionBase):
def __init__(self, session_key=None, last_ip=None, last_user_agent=""):
def __init__(
self,
session_key: str | None = None,
last_ip: str | None = None,
last_user_agent: str = "",
):
super().__init__(session_key)
self._create_kwargs = {
"last_ip": last_ip or ClientIPMiddleware.default_ip,
@@ -23,16 +30,16 @@ class SessionStore(SessionBase):
}
@classmethod
def get_model_class(cls):
def get_model_class(cls) -> type[Session]:
from authentik.core.models import Session
return Session
@cached_property
def model_fields(self):
def model_fields(self) -> list[str]:
return [k.value for k in self.model.Keys]
def _get_session_from_db(self):
def _get_session_from_db(self) -> Session:
try:
return (
self.model.objects.select_related(
@@ -74,10 +81,10 @@ class SessionStore(SessionBase):
LOGGER.warning(str(exc))
self._session_key = None
def encode(self, session_dict):
def encode(self, session_dict: dict[str, Any]) -> bytes:
return pickle.dumps(session_dict, protocol=pickle.HIGHEST_PROTOCOL)
def decode(self, session_data):
def decode(self, session_data: bytes) -> dict[str, Any]:
try:
return pickle.loads(session_data) # nosec
except pickle.PickleError:
@@ -86,7 +93,7 @@ class SessionStore(SessionBase):
pass
return {}
def load(self):
def load(self) -> dict[str, Any]:
s = self._get_session_from_db()
if s:
return {
@@ -108,7 +115,7 @@ class SessionStore(SessionBase):
else:
return {}
def create_model_instance(self, data):
def create_model_instance(self, data: dict[str, Any]) -> Session:
args = {
"session_key": self._get_or_create_session_key(),
"expires": self.get_expiry_date(),

View File

@@ -3,6 +3,7 @@
from django.http.response import (
HttpResponseBadRequest,
HttpResponseForbidden,
HttpResponseNotAllowed,
HttpResponseNotFound,
HttpResponseServerError,
)
@@ -61,6 +62,6 @@ class ServerErrorView(TemplateView):
response_class = ServerErrorTemplateResponse
template_name = "if/error.html"
def dispatch(self, *args, **kwargs): # pragma: no cover
def dispatch(self, *args, **kwargs) -> HttpResponseNotAllowed: # pragma: no cover
"""Little wrapper so django accepts this function"""
return super().dispatch(*args, **kwargs)

20
monkeytype_config.py Normal file
View File

@@ -0,0 +1,20 @@
# Standard Library
import os
from collections.abc import Iterator
from contextlib import contextmanager
# 3rd-party
from monkeytype.config import DefaultConfig
class MonkeyConfig(DefaultConfig):
@contextmanager
def cli_context(self, command: str) -> Iterator[None]:
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
import django
django.setup()
yield
CONFIG = MonkeyConfig()