mirror of
https://github.com/goauthentik/authentik
synced 2026-05-05 22:52:42 +02:00
Compare commits
9 Commits
logoutresp
...
packages/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99a56a5b9c | ||
|
|
73afaed115 | ||
|
|
8b758402c0 | ||
|
|
050c9c31af | ||
|
|
921269f990 | ||
|
|
87732a413c | ||
|
|
8cfe83bd47 | ||
|
|
1df84d68dd | ||
|
|
7f8527461a |
@@ -246,11 +246,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
]
|
||||
|
||||
def get_ql_fields(self):
|
||||
from djangoql.schema import BoolField, StrField
|
||||
|
||||
from authentik.enterprise.search.fields import (
|
||||
JSONSearchField,
|
||||
)
|
||||
from akql.schema import BoolField, JSONSearchField, StrField
|
||||
|
||||
return [
|
||||
StrField(Group, "name"),
|
||||
|
||||
@@ -504,12 +504,7 @@ class UserViewSet(
|
||||
]
|
||||
|
||||
def get_ql_fields(self):
|
||||
from djangoql.schema import BoolField, StrField
|
||||
|
||||
from authentik.enterprise.search.fields import (
|
||||
ChoiceSearchField,
|
||||
JSONSearchField,
|
||||
)
|
||||
from akql.schema import BoolField, ChoiceSearchField, JSONSearchField, StrField
|
||||
|
||||
return [
|
||||
StrField(User, "username"),
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
"""DjangoQL search"""
|
||||
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Generator
|
||||
|
||||
from django.db import connection
|
||||
from django.db.models import Model, Q
|
||||
from djangoql.compat import text_type
|
||||
from djangoql.schema import StrField
|
||||
|
||||
|
||||
class JSONSearchField(StrField):
|
||||
"""JSON field for DjangoQL"""
|
||||
|
||||
model: Model
|
||||
|
||||
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
|
||||
# Set this in the constructor to not clobber the type variable
|
||||
self.type = "relation"
|
||||
self.suggest_nested = suggest_nested
|
||||
super().__init__(model, name, nullable)
|
||||
|
||||
def get_lookup(self, path, operator, value):
|
||||
search = "__".join(path)
|
||||
op, invert = self.get_operator(operator)
|
||||
q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
|
||||
return ~q if invert else q
|
||||
|
||||
def json_field_keys(self) -> Generator[tuple[str]]:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
WITH RECURSIVE "{self.name}_keys" AS (
|
||||
SELECT
|
||||
ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
|
||||
"{self.name}" -> jsonb_object_keys("{self.name}") AS value
|
||||
FROM {self.model._meta.db_table}
|
||||
WHERE "{self.name}" IS NOT NULL
|
||||
AND jsonb_typeof("{self.name}") = 'object'
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT
|
||||
ck.key_path_array || jsonb_object_keys(ck.value),
|
||||
ck.value -> jsonb_object_keys(ck.value) AS value
|
||||
FROM "{self.name}_keys" ck
|
||||
WHERE jsonb_typeof(ck.value) = 'object'
|
||||
),
|
||||
|
||||
unique_paths AS (
|
||||
SELECT DISTINCT key_path_array
|
||||
FROM "{self.name}_keys"
|
||||
)
|
||||
|
||||
SELECT key_path_array FROM unique_paths;
|
||||
""" # nosec
|
||||
)
|
||||
return (x[0] for x in cursor.fetchall())
|
||||
|
||||
def get_nested_options(self) -> OrderedDict:
|
||||
"""Get keys of all nested objects to show autocomplete"""
|
||||
if not self.suggest_nested:
|
||||
return OrderedDict()
|
||||
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
|
||||
|
||||
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
|
||||
if not parent_parts:
|
||||
parent_parts = []
|
||||
path = parts.pop(0)
|
||||
parent_parts.append(path)
|
||||
relation_key = "_".join(parent_parts)
|
||||
if len(parts) > 1:
|
||||
out_dict = {
|
||||
relation_key: {
|
||||
parts[0]: {
|
||||
"type": "relation",
|
||||
"relation": f"{relation_key}_{parts[0]}",
|
||||
}
|
||||
}
|
||||
}
|
||||
child_paths = recursive_function(parts.copy(), parent_parts.copy())
|
||||
child_paths.update(out_dict)
|
||||
return child_paths
|
||||
else:
|
||||
return {relation_key: {parts[0]: {}}}
|
||||
|
||||
relation_structure = defaultdict(dict)
|
||||
|
||||
for relations in self.json_field_keys():
|
||||
result = recursive_function([base_model_name] + relations)
|
||||
for relation_key, value in result.items():
|
||||
for sub_relation_key, sub_value in value.items():
|
||||
if not relation_structure[relation_key].get(sub_relation_key, None):
|
||||
relation_structure[relation_key][sub_relation_key] = sub_value
|
||||
else:
|
||||
relation_structure[relation_key][sub_relation_key].update(sub_value)
|
||||
|
||||
final_dict = defaultdict(dict)
|
||||
|
||||
for key, value in relation_structure.items():
|
||||
for sub_key, sub_value in value.items():
|
||||
if not sub_value:
|
||||
final_dict[key][sub_key] = {
|
||||
"type": "str",
|
||||
"nullable": True,
|
||||
}
|
||||
else:
|
||||
final_dict[key][sub_key] = sub_value
|
||||
return OrderedDict(final_dict)
|
||||
|
||||
def relation(self) -> str:
|
||||
return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
|
||||
|
||||
|
||||
class ChoiceSearchField(StrField):
|
||||
def __init__(self, model=None, name=None, nullable=None):
|
||||
super().__init__(model, name, nullable, suggest_options=True)
|
||||
|
||||
def get_options(self, search):
|
||||
result = []
|
||||
choices = self._field_choices()
|
||||
if choices:
|
||||
search = search.lower()
|
||||
for c in choices:
|
||||
choice = text_type(c[0])
|
||||
if search in choice.lower():
|
||||
result.append(choice)
|
||||
return result
|
||||
@@ -1,18 +1,15 @@
|
||||
"""DjangoQL search"""
|
||||
"""QL search"""
|
||||
|
||||
from akql.exceptions import AKQLError
|
||||
from akql.queryset import apply_search
|
||||
from akql.schema import AKQLSchema
|
||||
from django.apps import apps
|
||||
from django.db.models import QuerySet
|
||||
from djangoql.ast import Name
|
||||
from djangoql.exceptions import DjangoQLError
|
||||
from djangoql.queryset import apply_search
|
||||
from djangoql.schema import DjangoQLSchema
|
||||
from drf_spectacular.plumbing import ResolvedComponent, build_object_type
|
||||
from rest_framework.filters import SearchFilter
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.enterprise.search.fields import JSONSearchField
|
||||
|
||||
LOGGER = get_logger()
|
||||
AUTOCOMPLETE_SCHEMA = ResolvedComponent(
|
||||
name="Autocomplete",
|
||||
@@ -22,27 +19,8 @@ AUTOCOMPLETE_SCHEMA = ResolvedComponent(
|
||||
)
|
||||
|
||||
|
||||
class BaseSchema(DjangoQLSchema):
|
||||
"""Base Schema which deals with JSON Fields"""
|
||||
|
||||
def resolve_name(self, name: Name):
|
||||
model = self.model_label(self.current_model)
|
||||
root_field = name.parts[0]
|
||||
field = self.models[model].get(root_field)
|
||||
# If the query goes into a JSON field, return the root
|
||||
# field as the JSON field will do the rest
|
||||
if isinstance(field, JSONSearchField):
|
||||
# This is a workaround; build_filter will remove the right-most
|
||||
# entry in the path as that is intended to be the same as the field
|
||||
# however for JSON that is not the case
|
||||
if name.parts[-1] != root_field:
|
||||
name.parts.append(root_field)
|
||||
return field
|
||||
return super().resolve_name(name)
|
||||
|
||||
|
||||
class QLSearch(SearchFilter):
|
||||
"""rest_framework search filter which uses DjangoQL"""
|
||||
"""rest_framework search filter which uses AKQL"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -59,24 +37,30 @@ class QLSearch(SearchFilter):
|
||||
params = params.replace("\x00", "") # strip null characters
|
||||
return params
|
||||
|
||||
def get_schema(self, request: Request, view) -> BaseSchema:
|
||||
def get_schema(self, request: Request, view) -> AKQLSchema:
|
||||
ql_fields = []
|
||||
if hasattr(view, "get_ql_fields"):
|
||||
ql_fields = view.get_ql_fields()
|
||||
|
||||
class InlineSchema(BaseSchema):
|
||||
class InlineSchema(AKQLSchema):
|
||||
def get_fields(self, model):
|
||||
return ql_fields or []
|
||||
|
||||
return InlineSchema
|
||||
|
||||
def get_search_context(self, request: Request):
|
||||
return {
|
||||
"$ak_user": request.user.pk,
|
||||
}
|
||||
|
||||
def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet:
|
||||
search_query = self.get_search_terms(request)
|
||||
schema = self.get_schema(request, view)
|
||||
if len(search_query) == 0 or not self.enabled:
|
||||
return self._fallback.filter_queryset(request, queryset, view)
|
||||
context = self.get_search_context(request)
|
||||
try:
|
||||
return apply_search(queryset, search_query, schema=schema)
|
||||
except DjangoQLError as exc:
|
||||
return apply_search(queryset, search_query, context=context, schema=schema)
|
||||
except AKQLError as exc:
|
||||
LOGGER.debug("Failed to parse search expression", exc=exc)
|
||||
return self._fallback.filter_queryset(request, queryset, view)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from djangoql.serializers import DjangoQLSchemaSerializer
|
||||
from akql.schema import JSONSearchField
|
||||
from akql.serializers import AKQLSchemaSerializer
|
||||
from drf_spectacular.generators import SchemaGenerator
|
||||
|
||||
from authentik.enterprise.search.fields import JSONSearchField
|
||||
from authentik.enterprise.search.ql import AUTOCOMPLETE_SCHEMA
|
||||
|
||||
|
||||
class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
|
||||
class AKQLSchemaSerializer(AKQLSchemaSerializer):
|
||||
def serialize(self, schema):
|
||||
serialization = super().serialize(schema)
|
||||
for _, fields in schema.models.items():
|
||||
@@ -15,12 +15,6 @@ class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
|
||||
serialization["models"].update(field.get_nested_options())
|
||||
return serialization
|
||||
|
||||
def serialize_field(self, field):
|
||||
result = super().serialize_field(field)
|
||||
if isinstance(field, JSONSearchField):
|
||||
result["relation"] = field.relation()
|
||||
return result
|
||||
|
||||
|
||||
def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs):
|
||||
generator.registry.register_on_missing(AUTOCOMPLETE_SCHEMA)
|
||||
|
||||
@@ -136,9 +136,7 @@ class EventViewSet(
|
||||
filterset_class = EventsFilter
|
||||
|
||||
def get_ql_fields(self):
|
||||
from djangoql.schema import DateTimeField, StrField
|
||||
|
||||
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
|
||||
from akql.schema import ChoiceSearchField, DateTimeField, JSONSearchField, StrField
|
||||
|
||||
return [
|
||||
ChoiceSearchField(Event, "action"),
|
||||
|
||||
21
packages/akql/LICENSE
Normal file
21
packages/akql/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2017 ivelum
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
3
packages/akql/README.md
Normal file
3
packages/akql/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
This is a fork of djangoql.
|
||||
|
||||
https://github.com/ivelum/djangoql
|
||||
1
packages/akql/akql/__init__.py
Normal file
1
packages/akql/akql/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.18.1"
|
||||
91
packages/akql/akql/ast.py
Normal file
91
packages/akql/akql/ast.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from akql.parser import AKQLParser
|
||||
|
||||
|
||||
class Node:
|
||||
def __str__(self):
|
||||
children = []
|
||||
for k, v in self.__dict__.items():
|
||||
vv = v
|
||||
if isinstance(v, list | tuple):
|
||||
vv = "[{}]".format(", ".join([str(v) for v in v if v]))
|
||||
children.append(f"{k}={vv}")
|
||||
return "<{}{}{}>".format(
|
||||
self.__class__.__name__,
|
||||
": " if children else "",
|
||||
", ".join(children),
|
||||
)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
for k, v in self.__dict__.items():
|
||||
if getattr(other, k) != v:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class Expression(Node):
|
||||
def __init__(self, left, operator, right):
|
||||
self.left = left
|
||||
self.operator = operator
|
||||
self.right = right
|
||||
|
||||
|
||||
class Name(Node):
|
||||
def __init__(self, parts):
|
||||
if isinstance(parts, list):
|
||||
self.parts = parts
|
||||
elif isinstance(parts, tuple):
|
||||
self.parts = list(parts)
|
||||
else:
|
||||
self.parts = [parts]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return ".".join(self.parts)
|
||||
|
||||
|
||||
class Const(Node):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class List(Node):
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return [i.value for i in self.items]
|
||||
|
||||
|
||||
class Operator(Node):
|
||||
def __init__(self, operator):
|
||||
self.operator = operator
|
||||
|
||||
|
||||
class Logical(Operator):
|
||||
pass
|
||||
|
||||
|
||||
class Comparison(Operator):
|
||||
pass
|
||||
|
||||
|
||||
class Variable(Node):
|
||||
|
||||
def __init__(self, name: str, parser: "AKQLParser"):
|
||||
self.name = name
|
||||
self.parser = parser
|
||||
|
||||
@property
|
||||
def value(self) -> Any:
|
||||
return self.parser.context.get(self.name)
|
||||
32
packages/akql/akql/exceptions.py
Normal file
32
packages/akql/akql/exceptions.py
Normal file
@@ -0,0 +1,32 @@
|
||||
class AKQLError(Exception):
|
||||
def __init__(self, message=None, value=None, line=None, column=None):
|
||||
self.value = value
|
||||
self.line = line
|
||||
self.column = column
|
||||
super().__init__(message)
|
||||
|
||||
def __str__(self):
|
||||
message = super().__str__()
|
||||
if self.line:
|
||||
position_info = f"Line {self.line}"
|
||||
if self.column:
|
||||
position_info += f", col {self.column}"
|
||||
return f"{position_info}: {message}"
|
||||
else:
|
||||
return message
|
||||
|
||||
|
||||
class AKQLSyntaxError(AKQLError):
|
||||
pass
|
||||
|
||||
|
||||
class AKQLLexerError(AKQLSyntaxError):
|
||||
pass
|
||||
|
||||
|
||||
class AKQLParserError(AKQLSyntaxError):
|
||||
pass
|
||||
|
||||
|
||||
class AKQLSchemaError(AKQLError):
|
||||
pass
|
||||
181
packages/akql/akql/lexer.py
Normal file
181
packages/akql/akql/lexer.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from ply import lex
|
||||
from ply.lex import TOKEN, Lexer, LexToken
|
||||
|
||||
from akql.exceptions import AKQLLexerError
|
||||
|
||||
|
||||
class AKQLLexer:
|
||||
_lexer: Lexer
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._lexer = lex.lex(module=self, **kwargs)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.text = ""
|
||||
self._lexer.lineno = 1
|
||||
return self
|
||||
|
||||
def input(self, s):
|
||||
self.reset()
|
||||
self.text = s
|
||||
self._lexer.input(s)
|
||||
return self
|
||||
|
||||
def token(self):
|
||||
return self._lexer.token()
|
||||
|
||||
# Iterator interface
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
t = self.token()
|
||||
if t is None:
|
||||
raise StopIteration
|
||||
return t
|
||||
|
||||
__next__ = next
|
||||
|
||||
def find_column(self, t: LexToken):
|
||||
"""
|
||||
Returns token position in current text, starting from 1
|
||||
"""
|
||||
cr = max(self.text.rfind(lt, 0, t.lexpos) for lt in self.line_terminators)
|
||||
if cr == -1:
|
||||
return t.lexpos + 1
|
||||
return t.lexpos - cr
|
||||
|
||||
whitespace = " \t\v\f\u00a0"
|
||||
line_terminators = "\n\r\u2028\u2029"
|
||||
|
||||
re_line_terminators = r"\n\r\u2028\u2029"
|
||||
|
||||
re_escaped_char = r"\\[\"\\/bfnrt]"
|
||||
re_escaped_unicode = r"\\u[0-9A-Fa-f]{4}"
|
||||
re_string_char = r"[^\"\\" + re_line_terminators + "]"
|
||||
|
||||
re_int_value = r"(-?0|-?[1-9][0-9]*)"
|
||||
re_fraction_part = r"\.[0-9]+"
|
||||
re_exponent_part = r"[eE][\+-]?[0-9]+"
|
||||
|
||||
tokens = [
|
||||
"COMMA",
|
||||
"OR",
|
||||
"AND",
|
||||
"NOT",
|
||||
"IN",
|
||||
"TRUE",
|
||||
"FALSE",
|
||||
"NONE",
|
||||
"NAME",
|
||||
"STRING_VALUE",
|
||||
"FLOAT_VALUE",
|
||||
"INT_VALUE",
|
||||
"PAREN_L",
|
||||
"PAREN_R",
|
||||
"EQUALS",
|
||||
"NOT_EQUALS",
|
||||
"GREATER",
|
||||
"GREATER_EQUAL",
|
||||
"LESS",
|
||||
"LESS_EQUAL",
|
||||
"CONTAINS",
|
||||
"NOT_CONTAINS",
|
||||
"STARTSWITH",
|
||||
"ENDSWITH",
|
||||
"VARIABLE",
|
||||
]
|
||||
|
||||
t_COMMA = ","
|
||||
t_PAREN_L = r"\("
|
||||
t_PAREN_R = r"\)"
|
||||
t_EQUALS = "="
|
||||
t_NOT_EQUALS = "!="
|
||||
t_GREATER = ">"
|
||||
t_GREATER_EQUAL = ">="
|
||||
t_LESS = "<"
|
||||
t_LESS_EQUAL = "<="
|
||||
t_CONTAINS = "~"
|
||||
t_NOT_CONTAINS = "!~"
|
||||
|
||||
t_NAME = r"[_A-Za-z][_0-9A-Za-z]*(\.[_A-Za-z][_0-9A-Za-z]*)*"
|
||||
|
||||
t_ignore = whitespace
|
||||
|
||||
@TOKEN(r"\$([_A-Za-z\.]+)")
|
||||
def t_VARIABLE(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN(r"\"(" + re_escaped_char + "|" + re_escaped_unicode + "|" + re_string_char + r")*\"")
|
||||
def t_STRING_VALUE(self, t: LexToken):
|
||||
t.value = t.value[1:-1] # cut leading and trailing quotes ""
|
||||
return t
|
||||
|
||||
@TOKEN(
|
||||
re_int_value
|
||||
+ re_fraction_part
|
||||
+ re_exponent_part
|
||||
+ "|"
|
||||
+ re_int_value
|
||||
+ re_fraction_part
|
||||
+ "|"
|
||||
+ re_int_value
|
||||
+ re_exponent_part
|
||||
)
|
||||
def t_FLOAT_VALUE(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN(re_int_value)
|
||||
def t_INT_VALUE(self, t: LexToken):
|
||||
return t
|
||||
|
||||
not_followed_by_name = "(?![_0-9A-Za-z])"
|
||||
|
||||
@TOKEN("or" + not_followed_by_name)
|
||||
def t_OR(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("and" + not_followed_by_name)
|
||||
def t_AND(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("not" + not_followed_by_name)
|
||||
def t_NOT(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("in" + not_followed_by_name)
|
||||
def t_IN(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("startswith" + not_followed_by_name)
|
||||
def t_STARTSWITH(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("endswith" + not_followed_by_name)
|
||||
def t_ENDSWITH(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("True" + not_followed_by_name)
|
||||
def t_TRUE(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("False" + not_followed_by_name)
|
||||
def t_FALSE(self, t: LexToken):
|
||||
return t
|
||||
|
||||
@TOKEN("None" + not_followed_by_name)
|
||||
def t_NONE(self, t: LexToken):
|
||||
return t
|
||||
|
||||
def t_error(self, t: LexToken):
|
||||
raise AKQLLexerError(
|
||||
message=f"Illegal character {repr(t.value[0])}",
|
||||
value=t.value,
|
||||
line=t.lineno,
|
||||
column=self.find_column(t),
|
||||
)
|
||||
|
||||
@TOKEN("[" + re_line_terminators + "]+")
|
||||
def t_newline(self, t: LexToken):
|
||||
t.lexer.lineno += len(t.value)
|
||||
239
packages/akql/akql/parser.py
Normal file
239
packages/akql/akql/parser.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import re
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from ply import yacc
|
||||
from ply.yacc import LRParser, YaccProduction
|
||||
|
||||
from akql.ast import Comparison, Const, Expression, List, Logical, Name, Variable
|
||||
from akql.exceptions import AKQLParserError
|
||||
from akql.lexer import AKQLLexer
|
||||
|
||||
unescape_pattern = re.compile(
|
||||
"(" + AKQLLexer.re_escaped_char + "|" + AKQLLexer.re_escaped_unicode + ")",
|
||||
)
|
||||
|
||||
|
||||
def unescape_repl(m: re.Match[str]) -> str:
|
||||
contents = m.group(1)
|
||||
if len(contents) == 2: # noqa
|
||||
return contents[1]
|
||||
else:
|
||||
return contents.encode("utf8").decode("unicode_escape")
|
||||
|
||||
|
||||
def unescape(value):
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf8")
|
||||
return re.sub(unescape_pattern, unescape_repl, value)
|
||||
|
||||
|
||||
class AKQLParser:
|
||||
yacc: LRParser
|
||||
context: dict[str, Any]
|
||||
|
||||
def __init__(self, debug=False, context: dict[str, Any] | None = None, **kwargs):
|
||||
self.default_lexer = AKQLLexer()
|
||||
self.tokens = self.default_lexer.tokens
|
||||
kwargs["debug"] = debug
|
||||
if "write_tables" not in kwargs:
|
||||
kwargs["write_tables"] = False
|
||||
self.context = context or {}
|
||||
self.yacc = yacc.yacc(module=self, **kwargs)
|
||||
|
||||
def parse(
|
||||
self, input=None, lexer: AKQLLexer | None = None, **kwargs
|
||||
) -> Expression: # noqa: A002
|
||||
lexer = lexer or self.default_lexer
|
||||
return self.yacc.parse(input=input, lexer=lexer, **kwargs)
|
||||
|
||||
start = "expression"
|
||||
|
||||
def p_expression_parens(self, p: YaccProduction):
|
||||
"""
|
||||
expression : PAREN_L expression PAREN_R
|
||||
"""
|
||||
p[0] = p[2]
|
||||
|
||||
def p_expression_logical(self, p: YaccProduction):
|
||||
"""
|
||||
expression : expression logical expression
|
||||
"""
|
||||
p[0] = Expression(left=p[1], operator=p[2], right=p[3])
|
||||
|
||||
def p_expression_comparison(self, p: YaccProduction):
|
||||
"""
|
||||
expression : name comparison_number number
|
||||
| name comparison_string string
|
||||
| name comparison_equality boolean_value
|
||||
| name comparison_equality none
|
||||
| name comparison_in_list const_list_value
|
||||
| name comparison_number variable
|
||||
| name comparison_string variable
|
||||
| name comparison_equality variable
|
||||
| name comparison_in_list variable
|
||||
"""
|
||||
p[0] = Expression(left=p[1], operator=p[2], right=p[3])
|
||||
|
||||
def p_name(self, p: YaccProduction):
|
||||
"""
|
||||
name : NAME
|
||||
"""
|
||||
p[0] = Name(parts=p[1].split("."))
|
||||
|
||||
def p_logical(self, p: YaccProduction):
|
||||
"""
|
||||
logical : AND
|
||||
| OR
|
||||
"""
|
||||
p[0] = Logical(operator=p[1])
|
||||
|
||||
def p_comparison_number(self, p: YaccProduction):
|
||||
"""
|
||||
comparison_number : comparison_equality
|
||||
| comparison_greater_less
|
||||
"""
|
||||
p[0] = p[1]
|
||||
|
||||
def p_comparison_string(self, p: YaccProduction):
|
||||
"""
|
||||
comparison_string : comparison_equality
|
||||
| comparison_greater_less
|
||||
| comparison_string_specific
|
||||
"""
|
||||
p[0] = p[1]
|
||||
|
||||
def p_comparison_equality(self, p: YaccProduction):
|
||||
"""
|
||||
comparison_equality : EQUALS
|
||||
| NOT_EQUALS
|
||||
"""
|
||||
p[0] = Comparison(operator=p[1])
|
||||
|
||||
def p_comparison_greater_less(self, p: YaccProduction):
|
||||
"""
|
||||
comparison_greater_less : GREATER
|
||||
| GREATER_EQUAL
|
||||
| LESS
|
||||
| LESS_EQUAL
|
||||
"""
|
||||
p[0] = Comparison(operator=p[1])
|
||||
|
||||
def p_comparison_string_specific(self, p: YaccProduction):
|
||||
"""
|
||||
comparison_string_specific : CONTAINS
|
||||
| NOT_CONTAINS
|
||||
| STARTSWITH
|
||||
| NOT STARTSWITH
|
||||
| ENDSWITH
|
||||
| NOT ENDSWITH
|
||||
"""
|
||||
p[0] = Comparison(operator=" ".join(p[1:]))
|
||||
|
||||
def p_comparison_in_list(self, p: YaccProduction):
|
||||
"""
|
||||
comparison_in_list : IN
|
||||
| NOT IN
|
||||
"""
|
||||
p[0] = Comparison(operator=" ".join(p[1:]))
|
||||
|
||||
def p_const_value(self, p: YaccProduction):
|
||||
"""
|
||||
const_value : number
|
||||
| string
|
||||
| none
|
||||
| boolean_value
|
||||
"""
|
||||
p[0] = p[1]
|
||||
|
||||
def p_variable(self, p: YaccProduction):
|
||||
"""
|
||||
variable : VARIABLE
|
||||
"""
|
||||
p[0] = Variable(name=unescape(p[1]), parser=self)
|
||||
|
||||
def p_number_int(self, p: YaccProduction):
|
||||
"""
|
||||
number : INT_VALUE
|
||||
"""
|
||||
p[0] = Const(value=int(p[1]))
|
||||
|
||||
def p_number_float(self, p: YaccProduction):
|
||||
"""
|
||||
number : FLOAT_VALUE
|
||||
"""
|
||||
p[0] = Const(value=Decimal(p[1]))
|
||||
|
||||
def p_string(self, p: YaccProduction):
|
||||
"""
|
||||
string : STRING_VALUE
|
||||
"""
|
||||
p[0] = Const(value=unescape(p[1]))
|
||||
|
||||
def p_none(self, p: YaccProduction):
|
||||
"""
|
||||
none : NONE
|
||||
"""
|
||||
p[0] = Const(value=None)
|
||||
|
||||
def p_boolean_value(self, p: YaccProduction):
|
||||
"""
|
||||
boolean_value : true
|
||||
| false
|
||||
"""
|
||||
p[0] = p[1]
|
||||
|
||||
def p_true(self, p: YaccProduction):
|
||||
"""
|
||||
true : TRUE
|
||||
"""
|
||||
p[0] = Const(value=True)
|
||||
|
||||
def p_false(self, p: YaccProduction):
|
||||
"""
|
||||
false : FALSE
|
||||
"""
|
||||
p[0] = Const(value=False)
|
||||
|
||||
def p_const_list_value(self, p: YaccProduction):
|
||||
"""
|
||||
const_list_value : PAREN_L const_value_list PAREN_R
|
||||
"""
|
||||
p[0] = List(items=p[2])
|
||||
|
||||
def p_const_value_list(self, p: YaccProduction):
|
||||
"""
|
||||
const_value_list : const_value_list COMMA const_value
|
||||
"""
|
||||
p[0] = p[1] + [p[3]]
|
||||
|
||||
def p_const_value_list_single(self, p: YaccProduction):
|
||||
"""
|
||||
const_value_list : const_value
|
||||
"""
|
||||
p[0] = [p[1]]
|
||||
|
||||
def p_error(self, token):
|
||||
if token is None:
|
||||
self.raise_syntax_error("Unexpected end of input")
|
||||
else:
|
||||
fragment = str(token.value)
|
||||
self.raise_syntax_error(
|
||||
f"Syntax error at {repr(fragment)}",
|
||||
token=token,
|
||||
)
|
||||
|
||||
def raise_syntax_error(self, message, token=None):
|
||||
if token is None:
|
||||
raise AKQLParserError(message)
|
||||
lexer = token.lexer
|
||||
if callable(getattr(lexer, "find_column", None)):
|
||||
column = lexer.find_column(token)
|
||||
else:
|
||||
column = None
|
||||
raise AKQLParserError(
|
||||
message=message,
|
||||
value=token.value,
|
||||
line=token.lineno,
|
||||
column=column,
|
||||
)
|
||||
1113
packages/akql/akql/parsetab.py
Normal file
1113
packages/akql/akql/parsetab.py
Normal file
File diff suppressed because it is too large
Load Diff
47
packages/akql/akql/queryset.py
Normal file
47
packages/akql/akql/queryset.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from akql.ast import Logical
|
||||
from akql.parser import AKQLParser
|
||||
from akql.schema import AKQLField, AKQLSchema
|
||||
|
||||
|
||||
def build_filter(expr: str, schema_instance: AKQLSchema):
|
||||
if isinstance(expr.operator, Logical):
|
||||
left = build_filter(expr.left, schema_instance)
|
||||
right = build_filter(expr.right, schema_instance)
|
||||
if expr.operator.operator == "or":
|
||||
return left | right
|
||||
else:
|
||||
return left & right
|
||||
|
||||
field = schema_instance.resolve_name(expr.left)
|
||||
if not field:
|
||||
# That must be a reference to a model without specifying a field.
|
||||
# Let's construct an abstract lookup field for it
|
||||
field = AKQLField(
|
||||
name=expr.left.parts[-1],
|
||||
nullable=True,
|
||||
)
|
||||
return field.get_lookup(
|
||||
path=expr.left.parts[:-1],
|
||||
operator=expr.operator.operator,
|
||||
value=expr.right.value,
|
||||
)
|
||||
|
||||
|
||||
def apply_search(
|
||||
queryset: QuerySet,
|
||||
search: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
schema: type[AKQLSchema] | None = None,
|
||||
) -> QuerySet:
|
||||
"""
|
||||
Applies search written in DjangoQL mini-language to given queryset
|
||||
"""
|
||||
ast = AKQLParser(context=context).parse(search)
|
||||
schema = schema or AKQLSchema
|
||||
schema_instance = schema(queryset.model)
|
||||
schema_instance.validate(ast)
|
||||
return queryset.filter(build_filter(ast, schema_instance))
|
||||
618
packages/akql/akql/schema.py
Normal file
618
packages/akql/akql/schema.py
Normal file
@@ -0,0 +1,618 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db import connection, models
|
||||
from django.db.models import ManyToManyRel, ManyToOneRel, Model, Q
|
||||
from django.db.models.fields.related import ForeignObjectRel
|
||||
from django.utils.timezone import get_current_timezone
|
||||
|
||||
from akql.ast import Comparison, Const, List, Logical, Name, Node, Variable
|
||||
from akql.exceptions import AKQLSchemaError
|
||||
|
||||
|
||||
class AKQLField:
|
||||
"""
|
||||
Abstract searchable field
|
||||
"""
|
||||
|
||||
model = None
|
||||
name = None
|
||||
nullable = False
|
||||
suggest_options = False
|
||||
type = "unknown"
|
||||
value_types = []
|
||||
value_types_description = ""
|
||||
|
||||
def __init__(self, model=None, name=None, nullable=None, suggest_options=None):
|
||||
if model is not None:
|
||||
self.model = model
|
||||
if name is not None:
|
||||
self.name = name
|
||||
if nullable is not None:
|
||||
self.nullable = nullable
|
||||
if suggest_options is not None:
|
||||
self.suggest_options = suggest_options
|
||||
|
||||
def _field_choices(self):
|
||||
if self.model:
|
||||
try:
|
||||
return self.model._meta.get_field(self.name).choices
|
||||
except (AttributeError, FieldDoesNotExist):
|
||||
pass
|
||||
return []
|
||||
|
||||
@property
|
||||
def async_options(self):
|
||||
return not self._field_choices()
|
||||
|
||||
def get_options(self, search):
|
||||
"""
|
||||
Override this method to provide custom suggestion options
|
||||
"""
|
||||
result = []
|
||||
choices = self._field_choices()
|
||||
if choices:
|
||||
search = search.lower()
|
||||
for c in choices:
|
||||
choice = str(c[1])
|
||||
if search in choice.lower():
|
||||
result.append(choice)
|
||||
return result
|
||||
|
||||
def get_lookup_name(self):
|
||||
"""
|
||||
Override this method to provide custom lookup name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_lookup_value(self, value):
|
||||
"""
|
||||
Override this method to convert displayed values to lookup values
|
||||
"""
|
||||
choices = self._field_choices()
|
||||
if choices:
|
||||
if isinstance(value, list):
|
||||
return [c[0] for c in choices if c[0] in value or c[1] in value]
|
||||
else:
|
||||
for c in choices:
|
||||
if value in c:
|
||||
return c[0]
|
||||
return value
|
||||
|
||||
def get_operator(self, operator):
|
||||
"""
|
||||
Get a comparison suffix to be used in Django ORM & inversion flag for it
|
||||
|
||||
:param operator: string, DjangoQL comparison operator
|
||||
:return: (suffix, invert) - a tuple with 2 values:
|
||||
suffix - suffix to be used in ORM query, for example '__gt' for '>'
|
||||
invert - boolean, True if this comparison needs to be inverted
|
||||
"""
|
||||
op = {
|
||||
"=": "",
|
||||
">": "__gt",
|
||||
">=": "__gte",
|
||||
"<": "__lt",
|
||||
"<=": "__lte",
|
||||
"~": "__icontains",
|
||||
"in": "__in",
|
||||
"startswith": "__istartswith",
|
||||
"endswith": "__iendswith",
|
||||
}.get(operator)
|
||||
if op is not None:
|
||||
return op, False
|
||||
op = {
|
||||
"!=": "",
|
||||
"!~": "__icontains",
|
||||
"not in": "__in",
|
||||
"not startswith": "__istartswith",
|
||||
"not endswith": "__iendswith",
|
||||
}[operator]
|
||||
return op, True
|
||||
|
||||
def get_lookup(self, path, operator, value):
|
||||
"""
|
||||
Performs a lookup for this field with given path, operator and value.
|
||||
|
||||
Override this if you'd like to implement a fully custom lookup. It
|
||||
should support all comparison operators compatible with the field type.
|
||||
|
||||
:param path: a list of names preceding current lookup. For example,
|
||||
if expression looks like 'author.groups.name = "Foo"' path would
|
||||
be ['author', 'groups']. 'name' is not included, because it's the
|
||||
current field instance itself.
|
||||
:param operator: a string with comparison operator. It could be one of
|
||||
the following: '=', '!=', '>', '>=', '<', '<=', '~', '!~', 'in',
|
||||
'not in'. Depending on the field type, some operators may be
|
||||
excluded. '~' and '!~' can be applied to StrField only and aren't
|
||||
allowed for any other fields. BoolField can't be used with less or
|
||||
greater operators, '>', '>=', '<' and '<=' are excluded for it.
|
||||
:param value: value passed for comparison
|
||||
:return: Q-object
|
||||
"""
|
||||
search = "__".join(path + [self.get_lookup_name()])
|
||||
op, invert = self.get_operator(operator)
|
||||
q = models.Q(**{f"{search}{op}": self.get_lookup_value(value)})
|
||||
return ~q if invert else q
|
||||
|
||||
def validate(self, value):
|
||||
if not self.nullable and value is None:
|
||||
raise AKQLSchemaError(
|
||||
f"Field {self.name} is not nullable, " "can't compare it to None",
|
||||
)
|
||||
if value is not None and type(value) not in self.value_types:
|
||||
if self.nullable:
|
||||
msg = (
|
||||
'Field "{field}" has "nullable {field_type}" type. '
|
||||
"It can be compared to {possible_values} or None, "
|
||||
"but not to {value}"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
'Field "{field}" has "{field_type}" type. It can '
|
||||
"be compared to {possible_values}, "
|
||||
"but not to {value}"
|
||||
)
|
||||
raise AKQLSchemaError(
|
||||
msg.format(
|
||||
field=self.name,
|
||||
field_type=self.type,
|
||||
possible_values=self.value_types_description,
|
||||
value=repr(value),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class IntField(AKQLField):
|
||||
type = "int"
|
||||
value_types = [int]
|
||||
value_types_description = "integer numbers"
|
||||
|
||||
def validate(self, value):
|
||||
"""
|
||||
Support enum-like choices defined on an integer field
|
||||
"""
|
||||
return super().validate(self.get_lookup_value(value))
|
||||
|
||||
|
||||
class FloatField(AKQLField):
|
||||
type = "float"
|
||||
value_types = [int, float, Decimal]
|
||||
value_types_description = "floating point numbers"
|
||||
|
||||
|
||||
class StrField(AKQLField):
|
||||
type = "str"
|
||||
value_types = [str]
|
||||
value_types_description = "strings"
|
||||
|
||||
def get_options(self, search):
|
||||
choice_options = super().get_options(search)
|
||||
if choice_options:
|
||||
return choice_options
|
||||
lookup = {}
|
||||
if search:
|
||||
lookup[f"{self.name}__icontains"] = search
|
||||
return (
|
||||
self.model.objects.filter(**lookup)
|
||||
.order_by(self.name)
|
||||
.values_list(self.name, flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
|
||||
class BoolField(AKQLField):
|
||||
type = "bool"
|
||||
value_types = [bool]
|
||||
value_types_description = "True or False"
|
||||
|
||||
|
||||
class DateField(AKQLField):
|
||||
type = "date"
|
||||
value_types = [str]
|
||||
value_types_description = 'dates in "YYYY-MM-DD" format'
|
||||
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
try:
|
||||
self.get_lookup_value(value)
|
||||
except ValueError as exc:
|
||||
raise AKQLSchemaError(
|
||||
f'Field "{self.name}" can be compared to dates in '
|
||||
f'"YYYY-MM-DD" format, but not to {repr(value)}',
|
||||
) from exc
|
||||
|
||||
def get_lookup_value(self, value):
|
||||
if not value:
|
||||
return None
|
||||
return datetime.strptime(value, "%Y-%m-%d").date()
|
||||
|
||||
|
||||
class DateTimeField(AKQLField):
|
||||
type = "datetime"
|
||||
value_types = [str]
|
||||
value_types_description = 'timestamps in "YYYY-MM-DD HH:MM" format'
|
||||
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
try:
|
||||
self.get_lookup_value(value)
|
||||
except ValueError as exc:
|
||||
raise AKQLSchemaError(
|
||||
f'Field "{self.name}" can be compared to timestamps in '
|
||||
f'"YYYY-MM-DD HH:MM" format, but not to {repr(value)}',
|
||||
) from exc
|
||||
|
||||
def get_lookup_value(self, value):
|
||||
if not value:
|
||||
return None
|
||||
for format in [
|
||||
"%Y-%m-%d",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
]:
|
||||
try:
|
||||
dt = datetime.strptime(value, format)
|
||||
if settings.USE_TZ:
|
||||
dt = dt.replace(tzinfo=get_current_timezone())
|
||||
return dt
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_lookup(self, path, operator, value):
|
||||
search = "__".join(path + [self.get_lookup_name()])
|
||||
op, invert = self.get_operator(operator)
|
||||
|
||||
# Add LIKE operator support for datetime fields. For LIKE comparisons
|
||||
# we don't want to convert source value to datetime instance, because
|
||||
# it would effectively kill the idea. What we want is expressions like
|
||||
# 'created ~ "2017-01-30'
|
||||
# to be translated to
|
||||
# 'created LIKE %2017-01-30%',
|
||||
# but it would work only if we pass a string as a parameter. If we pass
|
||||
# a datetime instance, it would add time part in a form of 00:00:00,
|
||||
# and resulting comparison would look like
|
||||
# 'created LIKE %2017-01-30 00:00:00%'
|
||||
# which is not what we want for this case.
|
||||
val = value if operator in ("~", "!~") else self.get_lookup_value(value)
|
||||
|
||||
q = models.Q(**{f"{search}{op}": val})
|
||||
return ~q if invert else q
|
||||
|
||||
|
||||
class RelationField(AKQLField):
|
||||
type = "relation"
|
||||
|
||||
def __init__(self, model, name, related_model, nullable=False, suggest_options=False):
|
||||
super().__init__(
|
||||
model=model,
|
||||
name=name,
|
||||
nullable=nullable,
|
||||
suggest_options=suggest_options,
|
||||
)
|
||||
self.related_model = related_model
|
||||
|
||||
@property
|
||||
def relation(self):
|
||||
return AKQLSchema.model_label(self.related_model)
|
||||
|
||||
|
||||
class JSONSearchField(StrField):
|
||||
"""JSON field for DjangoQL"""
|
||||
|
||||
model: Model
|
||||
|
||||
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
|
||||
# Set this in the constructor to not clobber the type variable
|
||||
self.type = "relation"
|
||||
self.suggest_nested = suggest_nested
|
||||
super().__init__(model, name, nullable)
|
||||
|
||||
def get_lookup(self, path, operator, value):
|
||||
search = "__".join(path)
|
||||
op, invert = self.get_operator(operator)
|
||||
q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
|
||||
return ~q if invert else q
|
||||
|
||||
def json_field_keys(self) -> Generator[tuple[str]]:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
WITH RECURSIVE "{self.name}_keys" AS (
|
||||
SELECT
|
||||
ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
|
||||
"{self.name}" -> jsonb_object_keys("{self.name}") AS value
|
||||
FROM {self.model._meta.db_table}
|
||||
WHERE "{self.name}" IS NOT NULL
|
||||
AND jsonb_typeof("{self.name}") = 'object'
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT
|
||||
ck.key_path_array || jsonb_object_keys(ck.value),
|
||||
ck.value -> jsonb_object_keys(ck.value) AS value
|
||||
FROM "{self.name}_keys" ck
|
||||
WHERE jsonb_typeof(ck.value) = 'object'
|
||||
),
|
||||
|
||||
unique_paths AS (
|
||||
SELECT DISTINCT key_path_array
|
||||
FROM "{self.name}_keys"
|
||||
)
|
||||
|
||||
SELECT key_path_array FROM unique_paths;
|
||||
""" # nosec
|
||||
)
|
||||
return (x[0] for x in cursor.fetchall())
|
||||
|
||||
def get_nested_options(self) -> OrderedDict:
|
||||
"""Get keys of all nested objects to show autocomplete"""
|
||||
if not self.suggest_nested:
|
||||
return OrderedDict()
|
||||
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
|
||||
|
||||
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
|
||||
if not parent_parts:
|
||||
parent_parts = []
|
||||
path = parts.pop(0)
|
||||
parent_parts.append(path)
|
||||
relation_key = "_".join(parent_parts)
|
||||
if len(parts) > 1:
|
||||
out_dict = {
|
||||
relation_key: {
|
||||
parts[0]: {
|
||||
"type": "relation",
|
||||
"relation": f"{relation_key}_{parts[0]}",
|
||||
}
|
||||
}
|
||||
}
|
||||
child_paths = recursive_function(parts.copy(), parent_parts.copy())
|
||||
child_paths.update(out_dict)
|
||||
return child_paths
|
||||
else:
|
||||
return {relation_key: {parts[0]: {}}}
|
||||
|
||||
relation_structure = defaultdict(dict)
|
||||
|
||||
for relations in self.json_field_keys():
|
||||
result = recursive_function([base_model_name] + relations)
|
||||
for relation_key, value in result.items():
|
||||
for sub_relation_key, sub_value in value.items():
|
||||
if not relation_structure[relation_key].get(sub_relation_key, None):
|
||||
relation_structure[relation_key][sub_relation_key] = sub_value
|
||||
else:
|
||||
relation_structure[relation_key][sub_relation_key].update(sub_value)
|
||||
|
||||
final_dict = defaultdict(dict)
|
||||
|
||||
for key, value in relation_structure.items():
|
||||
for sub_key, sub_value in value.items():
|
||||
if not sub_value:
|
||||
final_dict[key][sub_key] = {
|
||||
"type": "str",
|
||||
"nullable": True,
|
||||
}
|
||||
else:
|
||||
final_dict[key][sub_key] = sub_value
|
||||
return OrderedDict(final_dict)
|
||||
|
||||
def relation(self) -> str:
|
||||
return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
|
||||
|
||||
|
||||
class ChoiceSearchField(StrField):
|
||||
def __init__(self, model=None, name=None, nullable=None):
|
||||
super().__init__(model, name, nullable, suggest_options=True)
|
||||
|
||||
def get_options(self, search):
|
||||
result = []
|
||||
choices = self._field_choices()
|
||||
if choices:
|
||||
search = search.lower()
|
||||
for c in choices:
|
||||
choice = str(c[0])
|
||||
if search in choice.lower():
|
||||
result.append(choice)
|
||||
return result
|
||||
|
||||
|
||||
class AKQLSchema:
|
||||
include = () # models to include into introspection
|
||||
exclude = () # models to exclude from introspection
|
||||
suggest_options = None
|
||||
|
||||
def __init__(self, model):
|
||||
if not inspect.isclass(model) or not issubclass(model, models.Model):
|
||||
raise AKQLSchemaError(
|
||||
"Schema must be initialized with a subclass of Django model",
|
||||
)
|
||||
if self.include and self.exclude:
|
||||
raise AKQLSchemaError(
|
||||
"Either include or exclude can be specified, but not both",
|
||||
)
|
||||
if self.excluded(model):
|
||||
raise AKQLSchemaError(
|
||||
f"{model} can't be used with {self.__class__} because it's excluded from it",
|
||||
)
|
||||
self.current_model = model
|
||||
self._models = None
|
||||
if self.suggest_options is None:
|
||||
self.suggest_options = {}
|
||||
|
||||
def excluded(self, model):
|
||||
return model in self.exclude or (self.include and model not in self.include)
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
if not self._models:
|
||||
self._models = self.introspect(
|
||||
model=self.current_model,
|
||||
exclude=tuple(self.model_label(m) for m in self.exclude),
|
||||
)
|
||||
return self._models
|
||||
|
||||
@classmethod
|
||||
def model_label(self, model):
|
||||
return str(model._meta)
|
||||
|
||||
def introspect(self, model, exclude=()):
|
||||
"""
|
||||
Start with given model and recursively walk through its relationships.
|
||||
|
||||
Returns a dict with all model labels and their fields found.
|
||||
"""
|
||||
result = {}
|
||||
open_set = deque([model])
|
||||
closed_set = set(exclude)
|
||||
|
||||
while open_set:
|
||||
model = open_set.popleft()
|
||||
model_label = self.model_label(model)
|
||||
|
||||
if model_label in closed_set:
|
||||
continue
|
||||
|
||||
model_fields = OrderedDict()
|
||||
for field in self.get_fields(model):
|
||||
field_instance = field
|
||||
if not isinstance(field, AKQLField):
|
||||
field_instance = self.get_field_instance(model, field)
|
||||
if not field_instance:
|
||||
continue
|
||||
if isinstance(field_instance, RelationField):
|
||||
open_set.append(field_instance.related_model)
|
||||
model_fields[field_instance.name] = field_instance
|
||||
|
||||
result[model_label] = model_fields
|
||||
closed_set.add(model_label)
|
||||
|
||||
return result
|
||||
|
||||
def get_fields(self, model):
|
||||
"""
|
||||
By default, returns all field names of a given model.
|
||||
|
||||
Override this method to limit field options. You can either return a
|
||||
plain list of field names from it, like ['id', 'name'], or call
|
||||
.super() and exclude unwanted fields from its result.
|
||||
"""
|
||||
return sorted(
|
||||
[f.name for f in model._meta.get_fields() if f.name != "password"],
|
||||
)
|
||||
|
||||
def get_field_instance(self, model, field_name):
|
||||
field = model._meta.get_field(field_name)
|
||||
field_kwargs = {"model": model, "name": field.name}
|
||||
if field.is_relation:
|
||||
if not field.related_model:
|
||||
# GenericForeignKey
|
||||
return
|
||||
if self.excluded(field.related_model):
|
||||
return
|
||||
field_cls = RelationField
|
||||
field_kwargs["related_model"] = field.related_model
|
||||
else:
|
||||
field_cls = self.get_field_cls(field)
|
||||
if isinstance(field, ManyToOneRel | ManyToManyRel | ForeignObjectRel):
|
||||
# Django 1.8 doesn't have .null attribute for these fields
|
||||
field_kwargs["nullable"] = True
|
||||
else:
|
||||
field_kwargs["nullable"] = field.null
|
||||
field_kwargs["suggest_options"] = field.name in self.suggest_options.get(model, [])
|
||||
return field_cls(**field_kwargs)
|
||||
|
||||
def get_field_cls(self, field):
|
||||
str_fields = (
|
||||
models.CharField,
|
||||
models.TextField,
|
||||
models.UUIDField,
|
||||
models.BinaryField,
|
||||
models.GenericIPAddressField,
|
||||
)
|
||||
if isinstance(field, str_fields):
|
||||
return StrField
|
||||
elif isinstance(field, models.AutoField | models.IntegerField):
|
||||
return IntField
|
||||
elif isinstance(field, models.BooleanField | models.NullBooleanField):
|
||||
return BoolField
|
||||
elif isinstance(field, models.DecimalField | models.FloatField):
|
||||
return FloatField
|
||||
elif isinstance(field, models.DateTimeField):
|
||||
return DateTimeField
|
||||
elif isinstance(field, models.DateField):
|
||||
return DateField
|
||||
return AKQLField
|
||||
|
||||
def as_dict(self):
|
||||
from akql.serializers import AKQLSchemaSerializer
|
||||
|
||||
warnings.warn(
|
||||
"DjangoQLSchema.as_dict() is deprecated and will be removed in "
|
||||
"future releases. Please use DjangoQLSchemaSerializer instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return AKQLSchemaSerializer().serialize(self)
|
||||
|
||||
def resolve_name(self, name):
|
||||
assert isinstance(name, Name)
|
||||
model = self.model_label(self.current_model)
|
||||
|
||||
root_field = name.parts[0]
|
||||
field = self.models[model].get(root_field)
|
||||
# If the query goes into a JSON field, return the root
|
||||
# field as the JSON field will do the rest
|
||||
if isinstance(field, JSONSearchField):
|
||||
# This is a workaround; build_filter will remove the right-most
|
||||
# entry in the path as that is intended to be the same as the field
|
||||
# however for JSON that is not the case
|
||||
if name.parts[-1] != root_field:
|
||||
name.parts.append(root_field)
|
||||
return field
|
||||
|
||||
for name_part in name.parts:
|
||||
field = self.models[model].get(name_part)
|
||||
if not field:
|
||||
raise AKQLSchemaError(
|
||||
"Unknown field: {}. Possible choices are: {}".format(
|
||||
name_part,
|
||||
", ".join(sorted(self.models[model].keys())),
|
||||
),
|
||||
)
|
||||
if field.type == "relation":
|
||||
model = field.relation
|
||||
field = None
|
||||
return field
|
||||
|
||||
def validate(self, node):
|
||||
"""
|
||||
Validate DjangoQL AST tree vs. current schema
|
||||
"""
|
||||
assert isinstance(node, Node)
|
||||
if isinstance(node.operator, Logical):
|
||||
self.validate(node.left)
|
||||
self.validate(node.right)
|
||||
return
|
||||
assert isinstance(node.left, Name)
|
||||
assert isinstance(node.operator, Comparison)
|
||||
assert isinstance(node.right, Const | List | Variable)
|
||||
|
||||
# Check that field and value types are compatible
|
||||
field = self.resolve_name(node.left)
|
||||
value = node.right.value
|
||||
if field is None:
|
||||
if value is not None:
|
||||
raise AKQLSchemaError(
|
||||
f"Related model {node.left.value} can be compared to None only, but not to "
|
||||
f"{type(value).__name__}",
|
||||
)
|
||||
else:
|
||||
values = value if isinstance(node.right, List) else [value]
|
||||
for v in values:
|
||||
field.validate(v)
|
||||
31
packages/akql/akql/serializers.py
Normal file
31
packages/akql/akql/serializers.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from akql.schema import JSONSearchField, RelationField
|
||||
|
||||
|
||||
class AKQLSchemaSerializer:
|
||||
def serialize(self, schema):
|
||||
models = {}
|
||||
for model_label, fields in schema.models.items():
|
||||
models[model_label] = OrderedDict(
|
||||
[(name, self.serialize_field(f)) for name, f in fields.items()],
|
||||
)
|
||||
return {
|
||||
"current_model": schema.model_label(schema.current_model),
|
||||
"models": models,
|
||||
}
|
||||
|
||||
def serialize_field(self, field):
|
||||
result = {
|
||||
"type": field.type,
|
||||
"nullable": field.nullable,
|
||||
"options": self.serialize_field_options(field),
|
||||
}
|
||||
if isinstance(field, RelationField):
|
||||
result["relation"] = field.relation
|
||||
if isinstance(field, JSONSearchField):
|
||||
result["relation"] = field.relation()
|
||||
return result
|
||||
|
||||
def serialize_field_options(self, field):
|
||||
return list(field.get_options("")) if field.suggest_options else None
|
||||
0
packages/akql/akql/tests/__init__.py
Normal file
0
packages/akql/akql/tests/__init__.py
Normal file
16
packages/akql/akql/tests/test_filter.py
Normal file
16
packages/akql/akql/tests/test_filter.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from django.test import TestCase
|
||||
|
||||
from akql.queryset import apply_search
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import Notification
|
||||
|
||||
|
||||
class TestFilter(TestCase):
|
||||
|
||||
def test_filter(self):
|
||||
user = create_test_user()
|
||||
notif = Notification.objects.create(user=user)
|
||||
qs = apply_search(
|
||||
Notification.objects.all(), "user.id = $current_user", {"$current_user": user.pk}
|
||||
)
|
||||
self.assertEqual(qs.first(), notif)
|
||||
18
packages/akql/akql/tests/test_lexer.py
Normal file
18
packages/akql/akql/tests/test_lexer.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from django.test import TestCase
|
||||
|
||||
from akql.lexer import AKQLLexer
|
||||
|
||||
|
||||
class TestLexer(TestCase):
|
||||
|
||||
def test_lexer_simple(self):
|
||||
lexer = AKQLLexer().input('foo = "bar"')
|
||||
tokens = list(str(t) for t in lexer)
|
||||
self.assertEqual(
|
||||
tokens,
|
||||
[
|
||||
"LexToken(NAME,'foo',1,0)",
|
||||
"LexToken(EQUALS,'=',1,4)",
|
||||
"LexToken(STRING_VALUE,'bar',1,6)",
|
||||
],
|
||||
)
|
||||
41
packages/akql/akql/tests/test_parser.py
Normal file
41
packages/akql/akql/tests/test_parser.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from django.test import TestCase
|
||||
|
||||
from akql.ast import Comparison, Const, Expression, Name, Variable
|
||||
from akql.parser import AKQLParser
|
||||
|
||||
|
||||
class TestParser(TestCase):
|
||||
|
||||
def test_parser_simple(self):
|
||||
ast = AKQLParser().parse('foo = "bar"')
|
||||
self.assertEqual(
|
||||
ast,
|
||||
Expression(
|
||||
left=Name(parts=["foo"]),
|
||||
operator=Comparison(operator="="),
|
||||
right=Const(value="bar"),
|
||||
),
|
||||
)
|
||||
|
||||
def test_parser_not_startswith(self):
|
||||
ast = AKQLParser().parse('foo not startswith "bar"')
|
||||
self.assertEqual(
|
||||
ast,
|
||||
Expression(
|
||||
left=Name(parts=["foo"]),
|
||||
operator=Comparison(operator="not startswith"),
|
||||
right=Const(value="bar"),
|
||||
),
|
||||
)
|
||||
|
||||
def test_parser_variable(self):
|
||||
parser = AKQLParser()
|
||||
ast = parser.parse("foo = $bar")
|
||||
self.assertEqual(
|
||||
ast,
|
||||
Expression(
|
||||
left=Name(parts=["foo"]),
|
||||
operator=Comparison(operator="="),
|
||||
right=Variable(name="$bar", parser=parser),
|
||||
),
|
||||
)
|
||||
51
packages/akql/pyproject.toml
Normal file
51
packages/akql/pyproject.toml
Normal file
@@ -0,0 +1,51 @@
|
||||
[project]
|
||||
name = "akql"
|
||||
version = "3.2.0"
|
||||
description = "Model and object permissions for Django"
|
||||
requires-python = ">=3.9,<3.14"
|
||||
readme = "README.md"
|
||||
license = { text = "MIT" }
|
||||
authors = [
|
||||
{ name = "Authentik Security Inc.", email = "hello@goauthentik.io" },
|
||||
{ name = "Denis Stebunov", email = "support@ivelum.com" },
|
||||
]
|
||||
keywords = ["django", "permissions", "authorization", "object", "row", "level"]
|
||||
|
||||
classifiers = [
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
'Natural Language :: English',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Programming Language :: Python :: 3.11',
|
||||
'Programming Language :: Python :: 3.12',
|
||||
'Programming Language :: Python :: 3.13',
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"ply>=3.8",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/goauthentik/authentik/tree/main/packages/akql"
|
||||
Documentation = "https://github.com/goauthentik/authentik/tree/main/packages/akql"
|
||||
Repository = "https://github.com/goauthentik/authentik/tree/main/packages/akql"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [
|
||||
"akql",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages]
|
||||
find = {}
|
||||
@@ -6,6 +6,7 @@ authors = [{ name = "authentik Team", email = "hello@goauthentik.io" }]
|
||||
requires-python = "==3.13.*"
|
||||
dependencies = [
|
||||
"ak-guardian==3.2.0",
|
||||
"akql",
|
||||
"argon2-cffi==25.1.0",
|
||||
"channels==4.3.1",
|
||||
"cryptography==45.0.5",
|
||||
@@ -26,7 +27,6 @@ dependencies = [
|
||||
"django-prometheus==2.4.1",
|
||||
"django-storages[s3]==1.14.6",
|
||||
"django-tenants==3.9.0",
|
||||
"djangoql==0.18.1",
|
||||
"djangorestframework==3.16.1",
|
||||
"docker==7.1.0",
|
||||
"drf-orjson-renderer==1.7.3",
|
||||
@@ -121,6 +121,7 @@ no-binary-package = [
|
||||
|
||||
[tool.uv.sources]
|
||||
ak-guardian = { workspace = true }
|
||||
akql = { workspace = true }
|
||||
django-channels-postgres = { workspace = true }
|
||||
django-dramatiq-postgres = { workspace = true }
|
||||
django-postgres-cache = { workspace = true }
|
||||
@@ -129,6 +130,7 @@ opencontainers = { git = "https://github.com/vsoch/oci-python", rev = "ceb4fcc09
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
"packages/ak-guardian",
|
||||
"packages/akql",
|
||||
"packages/django-channels-postgres",
|
||||
"packages/django-dramatiq-postgres",
|
||||
"packages/django-postgres-cache",
|
||||
|
||||
27
uv.lock
generated
27
uv.lock
generated
@@ -5,6 +5,7 @@ requires-python = "==3.13.*"
|
||||
[manifest]
|
||||
members = [
|
||||
"ak-guardian",
|
||||
"akql",
|
||||
"authentik",
|
||||
"django-channels-postgres",
|
||||
"django-dramatiq-postgres",
|
||||
@@ -93,6 +94,17 @@ requires-dist = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.15'", specifier = ">=4.12.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "akql"
|
||||
version = "3.2.0"
|
||||
source = { editable = "packages/akql" }
|
||||
dependencies = [
|
||||
{ name = "ply" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [{ name = "ply", specifier = ">=3.8" }]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
version = "0.7.0"
|
||||
@@ -189,6 +201,7 @@ version = "2026.2.0rc1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "ak-guardian" },
|
||||
{ name = "akql" },
|
||||
{ name = "argon2-cffi" },
|
||||
{ name = "channels" },
|
||||
{ name = "cryptography" },
|
||||
@@ -209,7 +222,6 @@ dependencies = [
|
||||
{ name = "django-prometheus" },
|
||||
{ name = "django-storages", extra = ["s3"] },
|
||||
{ name = "django-tenants" },
|
||||
{ name = "djangoql" },
|
||||
{ name = "djangorestframework" },
|
||||
{ name = "docker" },
|
||||
{ name = "drf-orjson-renderer" },
|
||||
@@ -293,6 +305,7 @@ dev = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "ak-guardian", editable = "packages/ak-guardian" },
|
||||
{ name = "akql", editable = "packages/akql" },
|
||||
{ name = "argon2-cffi", specifier = "==25.1.0" },
|
||||
{ name = "channels", specifier = "==4.3.1" },
|
||||
{ name = "cryptography", specifier = "==45.0.5" },
|
||||
@@ -313,7 +326,6 @@ requires-dist = [
|
||||
{ name = "django-prometheus", specifier = "==2.4.1" },
|
||||
{ name = "django-storages", extras = ["s3"], specifier = "==1.14.6" },
|
||||
{ name = "django-tenants", specifier = "==3.9.0" },
|
||||
{ name = "djangoql", specifier = "==0.18.1" },
|
||||
{ name = "djangorestframework", specifier = "==3.16.1" },
|
||||
{ name = "docker", specifier = "==7.1.0" },
|
||||
{ name = "drf-orjson-renderer", specifier = "==1.7.3" },
|
||||
@@ -1252,17 +1264,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/57/918cfca627fcdc3441981dddc72a22be02e57abdb5391eb7339ea77a5ef4/django_tenants-3.9.0-py3-none-any.whl", hash = "sha256:14421088a4336444e2c4af54f21a6af2e57e53dcf95ba5d19b5fa17142cb460b", size = 215955, upload-time = "2025-09-06T21:46:05.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "djangoql"
|
||||
version = "0.18.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "ply" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/0a/83cdb7b9d3b854b98941363153945f6c051b3bc50cd61108a85677c98c3a/djangoql-0.18.1-py2.py3-none-any.whl", hash = "sha256:51b3085a805627ebb43cfd0aa861137cdf8f69cc3c9244699718fe04a6c8e26d", size = 218209, upload-time = "2024-01-08T14:10:47.915Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "djangorestframework"
|
||||
version = "3.16.1"
|
||||
|
||||
Reference in New Issue
Block a user