mirror of
https://github.com/goauthentik/authentik
synced 2026-05-09 00:22:24 +02:00
* reduce number of db queries * optimize group membership updates too * further optimize include_user=false and also members_by_pk * lint --------- Co-authored-by: Simonyi Gergő <28359278+gergosimonyi@users.noreply.github.com>
450 lines
16 KiB
Python
450 lines
16 KiB
Python
"""Groups API Viewset"""
|
|
|
|
from json import loads
|
|
|
|
from django.db.models import Prefetch
|
|
from django.http import Http404
|
|
from django.utils.translation import gettext as _
|
|
from django_filters.filters import CharFilter, ModelMultipleChoiceFilter
|
|
from django_filters.filterset import FilterSet
|
|
from djangoql.schema import BoolField, StrField
|
|
from drf_spectacular.utils import (
|
|
OpenApiParameter,
|
|
OpenApiResponse,
|
|
extend_schema,
|
|
extend_schema_field,
|
|
)
|
|
from guardian.shortcuts import get_objects_for_user
|
|
from rest_framework.authentication import SessionAuthentication
|
|
from rest_framework.decorators import action
|
|
from rest_framework.fields import CharField, IntegerField, SerializerMethodField
|
|
from rest_framework.permissions import IsAuthenticated
|
|
from rest_framework.relations import ManyRelatedField, PrimaryKeyRelatedField
|
|
from rest_framework.request import Request
|
|
from rest_framework.response import Response
|
|
from rest_framework.serializers import ListSerializer, ValidationError
|
|
from rest_framework.viewsets import ModelViewSet
|
|
|
|
from authentik.api.authentication import TokenAuthentication
|
|
from authentik.api.search.fields import (
|
|
JSONSearchField,
|
|
)
|
|
from authentik.api.validation import validate
|
|
from authentik.core.api.used_by import UsedByMixin
|
|
from authentik.core.api.utils import JSONDictField, ModelSerializer, PassiveSerializer
|
|
from authentik.core.models import Group, User
|
|
from authentik.endpoints.connectors.agent.auth import AgentAuth
|
|
from authentik.rbac.api.roles import RoleSerializer
|
|
from authentik.rbac.decorators import permission_required
|
|
|
|
|
|
class BulkManyRelatedField(ManyRelatedField):
    """ManyRelatedField that validates all PKs in a single query instead of one per PK.

    Validated data is the list of submitted primary keys (passed through the child's
    ``pk_field`` when one is set, duplicates dropped) rather than model instances.
    """

    def to_internal_value(self, data):
        if isinstance(data, str) or not hasattr(data, "__iter__"):
            self.fail("not_a_list", input_type=type(data).__name__)
        if not self.allow_empty and len(data) == 0:
            self.fail("empty")

        child = self.child_relation
        pk_field = child.pk_field
        # Coerced PK -> originally submitted value, for error reporting. Using a dict
        # also drops duplicate PKs while keeping first-seen order.
        pk_map = {}
        for item in data:
            # Per-item errors are raised through the child relation: ManyRelatedField
            # itself has no "incorrect_type" message, so self.fail() would raise an
            # AssertionError (HTTP 500) instead of a ValidationError.
            if isinstance(item, bool):
                child.fail("incorrect_type", data_type=type(item).__name__)
            pk = pk_field.to_internal_value(item) if pk_field else item
            try:
                pk_map[pk] = item
            except TypeError:
                # Unhashable input such as a nested list or object.
                child.fail("incorrect_type", data_type=type(item).__name__)
        if not pk_map:
            return []

        queryset = child.get_queryset()
        try:
            # Validate that all PKs exist with a single COUNT query.
            found_count = queryset.filter(pk__in=pk_map.keys()).count()
        except (TypeError, ValueError):
            # At least one value cannot be converted by the model's PK field
            # (e.g. "abc" for an integer PK); the per-PK pass below identifies it.
            found_count = None
        if found_count != len(pk_map):
            # Missing or malformed PKs (or a queryset whose joins duplicate rows):
            # fall back to per-PK checks so the error names the offending value.
            # This only runs when there's an actual validation problem (rare path).
            for pk, original in pk_map.items():
                try:
                    missing = not queryset.filter(pk=pk).exists()
                except (TypeError, ValueError):
                    child.fail("incorrect_type", data_type=type(original).__name__)
                else:
                    if missing:
                        child.fail("does_not_exist", pk_value=original)

        # Return raw PKs — Django's M2M set() accepts both objects and PKs,
        # using get_prep_value() for type coercion. This avoids loading all
        # objects into memory and avoids triggering post_init signals.
        return list(pk_map.keys())

    def to_representation(self, iterable):
        # For non-prefetched querysets, get PKs directly without loading model instances.
        # When prefetched, _result_cache is a list (possibly empty); when not, it's None.
        if hasattr(iterable, "values_list") and getattr(iterable, "_result_cache", None) is None:
            pks = iterable.values_list("pk", flat=True)
            pk_field = self.child_relation.pk_field
            if pk_field is not None:
                # Match PrimaryKeyRelatedField.to_representation, which routes the
                # PK through pk_field when one is configured.
                return [pk_field.to_representation(pk) for pk in pks]
            return list(pks)
        return super().to_representation(iterable)
|
|
|
|
|
|
class BulkPrimaryKeyRelatedField(PrimaryKeyRelatedField):
    """PrimaryKeyRelatedField that uses bulk validation when many=True."""

    @classmethod
    def many_init(cls, *args, **kwargs):
        """Build a BulkManyRelatedField wrapping one instance of this field.

        DRF's ``RelatedField.__new__`` calls this when the field is declared with
        ``many=True``. List-level options are routed to the parent field the same
        way DRF's own ``RelatedField.many_init`` does (via ``MANY_RELATION_KWARGS``),
        so options such as ``read_only``, ``label`` or ``help_text`` also apply to
        the list field instead of being silently dropped.
        """
        from rest_framework.relations import MANY_RELATION_KWARGS

        # List-only options: always removed so the child field never receives them,
        # but only forwarded to the list field when explicitly set.
        list_only = {
            key: kwargs.pop(key, None) for key in ("allow_empty", "max_length", "min_length")
        }
        child_relation = cls(*args, **kwargs)
        list_kwargs = {key: value for key, value in kwargs.items() if key in MANY_RELATION_KWARGS}
        list_kwargs.update({key: value for key, value in list_only.items() if value is not None})
        list_kwargs["child_relation"] = child_relation
        return BulkManyRelatedField(**list_kwargs)
|
|
|
|
|
|
# User model columns serialized by PartialUserSerializer. GroupViewSet also passes
# this list to QuerySet.only() when prefetching group members, so only these
# columns are loaded for each prefetched user.
PARTIAL_USER_SERIALIZER_MODEL_FIELDS = [
    "pk", "username", "name", "is_active",
    "last_login", "email", "attributes",
]
|
|
|
|
|
|
class PartialUserSerializer(ModelSerializer):
    """Partial User Serializer, does not include child relations."""

    # Read-only identifier; emitted after the model columns (see Meta.fields).
    uid = CharField(read_only=True)
    # JSON object of user attributes; optional on input.
    attributes = JSONDictField(required=False)

    class Meta:
        model = User
        # The shared model column list plus the read-only ``uid``.
        fields = [*PARTIAL_USER_SERIALIZER_MODEL_FIELDS, "uid"]
|
|
|
|
|
|
class RelatedGroupSerializer(ModelSerializer):
    """Stripped down group serializer to show relevant children/parents for groups"""

    # JSON object of group attributes; optional on input.
    attributes = JSONDictField(required=False)

    class Meta:
        model = Group
        # Only the group's own columns; no nested relation fields are included.
        fields = ["pk", "name", "is_superuser", "attributes", "group_uuid"]
|
|
|
|
|
|
class GroupSerializer(ModelSerializer):
    """Group Serializer"""

    attributes = JSONDictField(required=False)
    # All submitted user PKs are validated with a single query; validated_data
    # holds raw PKs rather than User instances.
    users = BulkPrimaryKeyRelatedField(queryset=User.objects.all(), many=True, default=list)
    parents = PrimaryKeyRelatedField(queryset=Group.objects.all(), many=True, required=False)
    # parents_obj, children_obj, users_obj and inherited_roles_obj are only expanded
    # when the matching include_* query parameter enables them; otherwise they are null.
    parents_obj = SerializerMethodField(allow_null=True)
    children_obj = SerializerMethodField(allow_null=True)
    users_obj = SerializerMethodField(allow_null=True)
    roles_obj = ListSerializer(
        child=RoleSerializer(),
        read_only=True,
        source="roles",
        required=False,
    )
    inherited_roles_obj = SerializerMethodField(allow_null=True)
    num_pk = IntegerField(read_only=True)

    def _include_flag(self, param: str, default: str) -> bool:
        """Return whether the ``include_*`` query parameter ``param`` is enabled.

        Only the string "true" (case-insensitive) enables a flag; ``default`` is used
        when the parameter is absent. Without a request in the serializer context,
        every flag is treated as enabled.
        """
        request: Request | None = self.context.get("request", None)
        if not request:
            return True
        return str(request.query_params.get(param, default)).lower() == "true"

    @property
    def _should_include_users(self) -> bool:
        return self._include_flag("include_users", "true")

    @property
    def _should_include_children(self) -> bool:
        return self._include_flag("include_children", "false")

    @property
    def _should_include_parents(self) -> bool:
        return self._include_flag("include_parents", "false")

    @property
    def _should_include_inherited_roles(self) -> bool:
        return self._include_flag("include_inherited_roles", "false")

    @extend_schema_field(PartialUserSerializer(many=True))
    def get_users_obj(self, instance: Group) -> list[PartialUserSerializer] | None:
        if not self._should_include_users:
            return None
        return PartialUserSerializer(instance.users, many=True).data

    @extend_schema_field(RelatedGroupSerializer(many=True))
    def get_children_obj(self, instance: Group) -> list[RelatedGroupSerializer] | None:
        if not self._should_include_children:
            return None
        return RelatedGroupSerializer(instance.children, many=True).data

    @extend_schema_field(RelatedGroupSerializer(many=True))
    def get_parents_obj(self, instance: Group) -> list[RelatedGroupSerializer] | None:
        if not self._should_include_parents:
            return None
        return RelatedGroupSerializer(instance.parents, many=True).data

    @extend_schema_field(RoleSerializer(many=True))
    def get_inherited_roles_obj(self, instance: Group) -> list | None:
        """Return only inherited roles from ancestor groups (excludes direct roles)"""
        if not self._should_include_inherited_roles:
            return None
        direct_role_pks = instance.roles.values_list("pk", flat=True)
        inherited_roles = instance.all_roles().exclude(pk__in=direct_role_pks)
        return RoleSerializer(inherited_roles, many=True).data

    def validate_is_superuser(self, superuser: bool):
        """Ensure that the user creating this group has permissions to set the superuser flag"""
        request: Request = self.context.get("request", None)
        if not request:
            return superuser
        # If we're updating an instance, and the state hasn't changed, we don't need to check perms
        if self.instance and superuser == self.instance.is_superuser:
            return superuser
        user: User = request.user
        perm = (
            "authentik_core.enable_group_superuser"
            if superuser
            else "authentik_core.disable_group_superuser"
        )
        # Creating a non-superuser group needs no extra permission.
        if self.instance or superuser:
            has_perm = user.has_perm(perm) or user.has_perm(perm, self.instance)
            if not has_perm:
                raise ValidationError(
                    # Translate the template first, then substitute: the catalog holds
                    # the msgid with the placeholder, not the formatted sentence.
                    _(
                        "User does not have permission to set "
                        "superuser status to {superuser_status}."
                    ).format_map({"superuser_status": superuser})
                )
        return superuser

    class Meta:
        model = Group
        fields = [
            "pk",
            "num_pk",
            "name",
            "is_superuser",
            "parents",
            "parents_obj",
            "users",
            "users_obj",
            "attributes",
            "roles",
            "roles_obj",
            "inherited_roles_obj",
            "children",
            "children_obj",
        ]
        extra_kwargs = {
            "children": {
                "required": False,
                "default": list,
            },
            # NOTE: DRF ignores extra_kwargs for explicitly declared fields, so this
            # entry has no effect on the `parents` field declared on the serializer.
            "parents": {
                "required": False,
                "default": list,
            },
        }
|
|
|
|
|
|
class GroupFilter(FilterSet):
    """Filter for groups"""

    attributes = CharFilter(
        field_name="attributes",
        lookup_expr="",
        label="Attributes",
        method="filter_attributes",
    )

    members_by_username = ModelMultipleChoiceFilter(
        field_name="users__username",
        to_field_name="username",
        queryset=User.objects.all(),
    )
    # distinct=False skips the DISTINCT django-filter otherwise applies.
    # NOTE(review): with several members_by_pk values, a group containing more than
    # one of those users matches once per user and may be listed repeatedly.
    members_by_pk = ModelMultipleChoiceFilter(
        field_name="users",
        queryset=User.objects.all(),
        distinct=False,
    )

    def filter_attributes(self, queryset, name, value):
        """Filter attributes by query args"""
        try:
            value = loads(value)
        except ValueError:
            raise ValidationError(detail="filter: failed to parse JSON") from None
        if not isinstance(value, dict):
            raise ValidationError(detail="filter: value must be key:value mapping")
        # Each top-level key becomes a JSON key lookup on the attributes column.
        lookups = {f"attributes__{key}": _value for key, _value in value.items()}
        try:
            filtered = queryset.filter(**lookups)
            # Run the query once so lookups Django only rejects at SQL-compile time
            # surface here. exists() does that with a LIMIT 1 probe instead of
            # fetching every matching row (and its prefetches) as len() would.
            filtered.exists()
        except ValueError:
            # Unusable lookup: leave the queryset unfiltered.
            return queryset
        return filtered

    class Meta:
        model = Group
        fields = ["name", "is_superuser", "members_by_pk", "attributes", "members_by_username"]
|
|
|
|
|
|
class GroupViewSet(UsedByMixin, ModelViewSet):
    """Group Viewset"""

    class UserAccountSerializer(PassiveSerializer):
        """Account adding/removing operations"""

        pk = IntegerField(required=True)

    queryset = Group.objects.none()
    serializer_class = GroupSerializer
    search_fields = ["name", "is_superuser"]
    filterset_class = GroupFilter
    ordering = ["name"]
    authentication_classes = [
        TokenAuthentication,
        SessionAuthentication,
        AgentAuth,
    ]

    def get_ql_fields(self):
        # Field definitions for the DjangoQL query schema.
        return [
            StrField(Group, "name"),
            BoolField(Group, "is_superuser", nullable=True),
            JSONSearchField(Group, "attributes"),
        ]

    def get_queryset(self):
        # Always prefetch parents and children since their PKs are always serialized
        base_qs = Group.objects.all().prefetch_related("roles", "parents", "children")

        if self.serializer_class(context={"request": self.request})._should_include_users:
            # Only fetch fields needed by PartialUserSerializer to reduce DB load and
            # instantiation time
            base_qs = base_qs.prefetch_related(
                Prefetch(
                    "users",
                    queryset=User.objects.all().only(*PARTIAL_USER_SERIALIZER_MODEL_FIELDS),
                )
            )
        # When include_users=false, skip users prefetch entirely.
        # BulkManyRelatedField.to_representation will use values_list to get PKs
        # directly without loading User instances into memory.
        return base_qs

    def _get_member_user(self, request: Request, body: UserAccountSerializer) -> User:
        """Return the user referenced by a membership request body.

        Only users the requester holds ``authentik_core.view_user`` for are
        considered; any other or nonexistent PK raises Http404.
        """
        user: User | None = (
            get_objects_for_user(request.user, "authentik_core.view_user")
            .filter(
                pk=body.validated_data.get("pk"),
            )
            .first()
        )
        if not user:
            raise Http404
        return user

    # No docstring on list/retrieve: it would become the OpenAPI operation description.
    @extend_schema(
        parameters=[
            OpenApiParameter("include_users", bool, default=True),
            OpenApiParameter("include_children", bool, default=False),
            OpenApiParameter("include_parents", bool, default=False),
            OpenApiParameter("include_inherited_roles", bool, default=False),
        ]
    )
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    @extend_schema(
        parameters=[
            OpenApiParameter("include_users", bool, default=True),
            OpenApiParameter("include_children", bool, default=False),
            OpenApiParameter("include_parents", bool, default=False),
            OpenApiParameter("include_inherited_roles", bool, default=False),
        ]
    )
    def retrieve(self, request, *args, **kwargs):
        return super().retrieve(request, *args, **kwargs)

    @permission_required("authentik_core.add_user_to_group")
    @extend_schema(
        request=UserAccountSerializer,
        responses={
            204: OpenApiResponse(description="User added"),
            404: OpenApiResponse(description="User not found"),
        },
    )
    @action(
        detail=True,
        methods=["POST"],
        pagination_class=None,
        filter_backends=[],
        permission_classes=[IsAuthenticated],
    )
    @validate(UserAccountSerializer)
    def add_user(self, request: Request, body: UserAccountSerializer, pk: str) -> Response:
        """Add user to group"""
        group: Group = self.get_object()
        group.users.add(self._get_member_user(request, body))
        return Response(status=204)

    @permission_required("authentik_core.remove_user_from_group")
    @extend_schema(
        request=UserAccountSerializer,
        responses={
            204: OpenApiResponse(description="User removed"),
            404: OpenApiResponse(description="User not found"),
        },
    )
    @action(
        detail=True,
        methods=["POST"],
        pagination_class=None,
        filter_backends=[],
        permission_classes=[IsAuthenticated],
    )
    @validate(UserAccountSerializer)
    def remove_user(self, request: Request, body: UserAccountSerializer, pk: str) -> Response:
        """Remove user from group"""
        group: Group = self.get_object()
        group.users.remove(self._get_member_user(request, body))
        return Response(status=204)
|